mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 01:26:37 +08:00
feat: support aws bedrockruntime claude3
closes #622, closes #749, closes #1300
This commit is contained in:
parent
b638a2fcbd
commit
4e1bfe4879
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -65,9 +66,9 @@ var EmailDomainWhitelist = []string{
|
|||||||
"foxmail.com",
|
"foxmail.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
var DebugEnabled = os.Getenv("DEBUG") == "true"
|
var DebugEnabled = strings.ToLower(os.Getenv("DEBUG")) == "true"
|
||||||
var DebugSQLEnabled = os.Getenv("DEBUG_SQL") == "true"
|
var DebugSQLEnabled = strings.ToLower(os.Getenv("DEBUG_SQL")) == "true"
|
||||||
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
var MemoryCacheEnabled = strings.ToLower(os.Getenv("MEMORY_CACHE_ENABLED")) == "true"
|
||||||
|
|
||||||
var LogConsumeEnabled = true
|
var LogConsumeEnabled = true
|
||||||
|
|
||||||
|
@ -4,3 +4,11 @@ import "time"
|
|||||||
|
|
||||||
var StartTime = time.Now().Unix() // unit: second
|
var StartTime = time.Now().Unix() // unit: second
|
||||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||||
|
|
||||||
|
var (
|
||||||
|
// CtxKeyChannel is the key to store the channel in the context
|
||||||
|
CtxKeyChannel string = "channel_docu"
|
||||||
|
CtxKeyRequestModel string = "request_model"
|
||||||
|
CtxKeyRawRequest string = "raw_request"
|
||||||
|
CtxKeyConvertedRequest string = "converted_request"
|
||||||
|
)
|
||||||
|
15
go.mod
15
go.mod
@ -5,6 +5,9 @@ go 1.21
|
|||||||
require (
|
require (
|
||||||
github.com/Laisky/errors/v2 v2.0.1
|
github.com/Laisky/errors/v2 v2.0.1
|
||||||
github.com/Laisky/go-utils/v4 v4.9.1
|
github.com/Laisky/go-utils/v4 v4.9.1
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.26.1
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.27.11
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||||
github.com/gin-contrib/cors v1.7.0
|
github.com/gin-contrib/cors v1.7.0
|
||||||
github.com/gin-contrib/gzip v0.0.6
|
github.com/gin-contrib/gzip v0.0.6
|
||||||
github.com/gin-contrib/sessions v0.0.5
|
github.com/gin-contrib/sessions v0.0.5
|
||||||
@ -33,6 +36,18 @@ require (
|
|||||||
github.com/Laisky/go-chaining v0.0.0-20180507092046-43dcdc5a21be // indirect
|
github.com/Laisky/go-chaining v0.0.0-20180507092046-43dcdc5a21be // indirect
|
||||||
github.com/Laisky/graphql v1.0.6 // indirect
|
github.com/Laisky/graphql v1.0.6 // indirect
|
||||||
github.com/Laisky/zap v1.27.0 // indirect
|
github.com/Laisky/zap v1.27.0 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 // indirect
|
||||||
|
github.com/aws/smithy-go v1.20.2 // indirect
|
||||||
github.com/bytedance/sonic v1.11.2 // indirect
|
github.com/bytedance/sonic v1.11.2 // indirect
|
||||||
github.com/cespare/xxhash v1.1.0 // indirect
|
github.com/cespare/xxhash v1.1.0 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
|
30
go.sum
30
go.sum
@ -16,6 +16,36 @@ github.com/Laisky/zap v1.27.0 h1:NEtFniRXOKUkEf9//FUiBkSIdgbB4V1H3khA/jmPZx4=
|
|||||||
github.com/Laisky/zap v1.27.0/go.mod h1:HABqM5YDQlPq8w+Pmp9h/x9F6Vy+3oHBLP+2+pBoaJw=
|
github.com/Laisky/zap v1.27.0/go.mod h1:HABqM5YDQlPq8w+Pmp9h/x9F6Vy+3oHBLP+2+pBoaJw=
|
||||||
github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
|
github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
|
||||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA=
|
||||||
|
github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4=
|
||||||
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw=
|
||||||
|
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||||
|
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||||
github.com/brianvoe/gofakeit/v6 v6.23.2 h1:lVde18uhad5wII/f5RMVFLtdQNE0HaGFuBUXmYKk8i8=
|
github.com/brianvoe/gofakeit/v6 v6.23.2 h1:lVde18uhad5wII/f5RMVFLtdQNE0HaGFuBUXmYKk8i8=
|
||||||
github.com/brianvoe/gofakeit/v6 v6.23.2/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8=
|
github.com/brianvoe/gofakeit/v6 v6.23.2/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8=
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
@ -73,6 +74,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
}
|
}
|
||||||
logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio))
|
logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio))
|
||||||
c.Set("channel_ratio", minimalRatio)
|
c.Set("channel_ratio", minimalRatio)
|
||||||
|
c.Set(common.CtxKeyChannel, channel)
|
||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
|
"github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
|
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/aws"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
|
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
@ -19,6 +20,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
|
|||||||
// return &ali.Adaptor{}
|
// return &ali.Adaptor{}
|
||||||
case apitype.Anthropic:
|
case apitype.Anthropic:
|
||||||
return &anthropic.Adaptor{}
|
return &anthropic.Adaptor{}
|
||||||
|
case apitype.AwsClaude:
|
||||||
|
return &aws.Adaptor{}
|
||||||
// case apitype.Baidu:
|
// case apitype.Baidu:
|
||||||
// return &baidu.Adaptor{}
|
// return &baidu.Adaptor{}
|
||||||
case apitype.Gemini:
|
case apitype.Gemini:
|
||||||
|
@ -92,7 +92,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://docs.anthropic.com/claude/reference/messages-streaming
|
// https://docs.anthropic.com/claude/reference/messages-streaming
|
||||||
func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
|
func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
|
||||||
var response *Response
|
var response *Response
|
||||||
var responseText string
|
var responseText string
|
||||||
var stopReason string
|
var stopReason string
|
||||||
@ -130,7 +130,7 @@ func streamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
|
|||||||
return &openaiResponse, response
|
return &openaiResponse, response
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
|
func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
|
||||||
var responseText string
|
var responseText string
|
||||||
if len(claudeResponse.Content) > 0 {
|
if len(claudeResponse.Content) > 0 {
|
||||||
responseText = claudeResponse.Content[0].Text
|
responseText = claudeResponse.Content[0].Text
|
||||||
@ -201,7 +201,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
|||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
response, meta := streamResponseClaude2OpenAI(&claudeResponse)
|
response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
|
||||||
if meta != nil {
|
if meta != nil {
|
||||||
usage.PromptTokens += meta.Usage.InputTokens
|
usage.PromptTokens += meta.Usage.InputTokens
|
||||||
usage.CompletionTokens += meta.Usage.OutputTokens
|
usage.CompletionTokens += meta.Usage.OutputTokens
|
||||||
@ -256,7 +256,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
|
|||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
|
||||||
fullTextResponse.Model = modelName
|
fullTextResponse.Model = modelName
|
||||||
usage := model.Usage{
|
usage := model.Usage{
|
||||||
PromptTokens: claudeResponse.Usage.InputTokens,
|
PromptTokens: claudeResponse.Usage.InputTokens,
|
||||||
|
78
relay/adaptor/aws/adapter.go
Normal file
78
relay/adaptor/aws/adapter.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/Laisky/errors/v2"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ adaptor.Adaptor = new(Adaptor)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
|
if request == nil {
|
||||||
|
return nil, errors.New("request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeReq := anthropic.ConvertRequest(*request)
|
||||||
|
c.Set(common.CtxKeyRequestModel, request.Model)
|
||||||
|
c.Set(common.CtxKeyRawRequest, request)
|
||||||
|
c.Set(common.CtxKeyConvertedRequest, claudeReq)
|
||||||
|
return claudeReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||||
|
if request == nil {
|
||||||
|
return nil, errors.New("request is nil")
|
||||||
|
}
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
|
if meta.IsStream {
|
||||||
|
err, usage = StreamHandler(c, resp)
|
||||||
|
} else {
|
||||||
|
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
|
||||||
|
var ModelList = []string{
|
||||||
|
"claude-3-haiku-20240307",
|
||||||
|
"claude-3-sonnet-20240229",
|
||||||
|
"claude-3-opus-20240229",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return "aws"
|
||||||
|
}
|
233
relay/adaptor/aws/main.go
Normal file
233
relay/adaptor/aws/main.go
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
// Package aws provides the AWS adaptor for the relay service.
|
||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/jinzhu/copier"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
|
||||||
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAwsClient(channel *model.Channel) (*bedrockruntime.Client, error) {
|
||||||
|
ks := strings.Split(channel.Key, "\n")
|
||||||
|
if len(ks) != 2 {
|
||||||
|
return nil, errors.New("invalid key")
|
||||||
|
}
|
||||||
|
ak, sk := ks[0], ks[1]
|
||||||
|
|
||||||
|
client := bedrockruntime.New(bedrockruntime.Options{
|
||||||
|
Region: *channel.BaseURL,
|
||||||
|
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
|
||||||
|
})
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapErr(err error) *relaymodel.ErrorWithStatusCode {
|
||||||
|
return &relaymodel.ErrorWithStatusCode{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Error: relaymodel.Error{
|
||||||
|
Message: fmt.Sprintf("%+v", err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
|
||||||
|
func awsModelID(requestModel string) (string, error) {
|
||||||
|
switch requestModel {
|
||||||
|
case "claude-instant-1.2":
|
||||||
|
return "anthropic.claude-instant-v1", nil
|
||||||
|
case "claude-2.0":
|
||||||
|
return "anthropic.claude-v2", nil
|
||||||
|
case "claude-2.1":
|
||||||
|
return "anthropic.claude-v2:1", nil
|
||||||
|
case "claude-3-sonnet-20240229":
|
||||||
|
return "anthropic.claude-3-sonnet-20240229-v1:0", nil
|
||||||
|
case "claude-3-opus-20240229":
|
||||||
|
return "anthropic.claude-3-opus-20240229-v1:0", nil
|
||||||
|
case "claude-3-haiku-20240307":
|
||||||
|
return "anthropic.claude-3-haiku-20240307-v1:0", nil
|
||||||
|
default:
|
||||||
|
return "", errors.Errorf("unknown model: %s", requestModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
|
||||||
|
var channel *model.Channel
|
||||||
|
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
|
||||||
|
return wrapErr(errors.New("channel not found")), nil
|
||||||
|
} else {
|
||||||
|
channel = channeli.(*model.Channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
awsCli, err := newAwsClient(channel)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsModelId, err := awsModelID(channel.Models)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
|
||||||
|
if !ok {
|
||||||
|
return wrapErr(errors.New("request not found")), nil
|
||||||
|
}
|
||||||
|
claudeReq := claudeReqi.(*anthropic.Request)
|
||||||
|
awsClaudeReq := &Request{
|
||||||
|
AnthropicVersion: "bedrock-2023-05-31",
|
||||||
|
}
|
||||||
|
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "copy request")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeResponse := new(anthropic.Response)
|
||||||
|
err = json.Unmarshal(awsResp.Body, claudeResponse)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
|
||||||
|
openaiResp.Model = modelName
|
||||||
|
usage := relaymodel.Usage{
|
||||||
|
PromptTokens: claudeResponse.Usage.InputTokens,
|
||||||
|
CompletionTokens: claudeResponse.Usage.OutputTokens,
|
||||||
|
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
openaiResp.Usage = usage
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, openaiResp)
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
|
||||||
|
createdTime := helper.GetTimestamp()
|
||||||
|
|
||||||
|
var channel *model.Channel
|
||||||
|
if channeli, ok := c.Get(common.CtxKeyChannel); !ok {
|
||||||
|
return wrapErr(errors.New("channel not found")), nil
|
||||||
|
} else {
|
||||||
|
channel = channeli.(*model.Channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
awsCli, err := newAwsClient(channel)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsModelId, err := awsModelID(channel.Models)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeReqi, ok := c.Get(common.CtxKeyConvertedRequest)
|
||||||
|
if !ok {
|
||||||
|
return wrapErr(errors.New("request not found")), nil
|
||||||
|
}
|
||||||
|
claudeReq := claudeReqi.(*anthropic.Request)
|
||||||
|
|
||||||
|
awsClaudeReq := &Request{
|
||||||
|
AnthropicVersion: "bedrock-2023-05-31",
|
||||||
|
}
|
||||||
|
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "copy request")), nil
|
||||||
|
}
|
||||||
|
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
|
||||||
|
}
|
||||||
|
stream := awsResp.GetStream()
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
var usage relaymodel.Usage
|
||||||
|
var id string
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
event, ok := <-stream.Events()
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := event.(type) {
|
||||||
|
case *types.ResponseStreamMemberChunk:
|
||||||
|
claudeResp := new(anthropic.StreamResponse)
|
||||||
|
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp)
|
||||||
|
if meta != nil {
|
||||||
|
usage.PromptTokens += meta.Usage.InputTokens
|
||||||
|
usage.CompletionTokens += meta.Usage.OutputTokens
|
||||||
|
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if response == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
response.Id = id
|
||||||
|
response.Model = c.GetString("original_model")
|
||||||
|
response.Created = createdTime
|
||||||
|
jsonStr, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||||
|
return true
|
||||||
|
case *types.UnknownUnionMember:
|
||||||
|
fmt.Println("unknown tag:", v.Tag)
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
fmt.Println("union is nil or unknown type")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, &usage
|
||||||
|
}
|
17
relay/adaptor/aws/model.go
Normal file
17
relay/adaptor/aws/model.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
|
||||||
|
|
||||||
|
// Request is the request to AWS Claude
|
||||||
|
//
|
||||||
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||||
|
type Request struct {
|
||||||
|
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||||
|
AnthropicVersion string `json:"anthropic_version"`
|
||||||
|
Messages []anthropic.Message `json:"messages"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
}
|
@ -12,6 +12,7 @@ const (
|
|||||||
Tencent
|
Tencent
|
||||||
Gemini
|
Gemini
|
||||||
Ollama
|
Ollama
|
||||||
|
AwsClaude
|
||||||
|
|
||||||
Dummy // this one is only for count, do not add any channel after this
|
Dummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
|
@ -34,6 +34,7 @@ const (
|
|||||||
Ollama
|
Ollama
|
||||||
LingYiWanWu
|
LingYiWanWu
|
||||||
StepFun
|
StepFun
|
||||||
|
AwsClaude
|
||||||
|
|
||||||
Dummy
|
Dummy
|
||||||
)
|
)
|
||||||
|
@ -25,6 +25,9 @@ func ToAPIType(channelType int) int {
|
|||||||
apiType = apitype.Gemini
|
apiType = apitype.Gemini
|
||||||
case Ollama:
|
case Ollama:
|
||||||
apiType = apitype.Ollama
|
apiType = apitype.Ollama
|
||||||
|
case AwsClaude:
|
||||||
|
apiType = apitype.AwsClaude
|
||||||
}
|
}
|
||||||
|
|
||||||
return apiType
|
return apiType
|
||||||
}
|
}
|
||||||
|
@ -34,6 +34,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"http://localhost:11434", // 30
|
"http://localhost:11434", // 30
|
||||||
"https://api.lingyiwanwu.com", // 31
|
"https://api.lingyiwanwu.com", // 31
|
||||||
"https://api.stepfun.com", // 32
|
"https://api.stepfun.com", // 32
|
||||||
|
"", // 33
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -109,9 +109,10 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := model.PostConsumeTokenQuota(meta.TokenId, quota)
|
err := model.PostConsumeTokenQuota(meta.TokenId, quota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
@ -94,6 +94,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
||||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
if resp != nil {
|
||||||
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
|
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
|
||||||
if errorHappened {
|
if errorHappened {
|
||||||
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||||
@ -102,6 +103,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
|||||||
return RelayErrorHandler(resp)
|
return RelayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
|
}
|
||||||
|
|
||||||
// do response
|
// do response
|
||||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "react-scripts start",
|
"start": "react-scripts start",
|
||||||
"build": "react-scripts build && mv -f build ../build/default",
|
"build": "react-scripts build && rm -rf ../build/default && mv -f build ../build/default",
|
||||||
"test": "react-scripts test",
|
"test": "react-scripts test",
|
||||||
"eject": "react-scripts eject"
|
"eject": "react-scripts eject"
|
||||||
},
|
},
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
export const CHANNEL_OPTIONS = [
|
export const CHANNEL_OPTIONS = [
|
||||||
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
{ key: 1, text: 'OpenAI', value: 1, color: 'green' },
|
||||||
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
|
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
|
||||||
|
{ key: 33, text: 'AWS Claude', value: 33, color: 'black' },
|
||||||
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
|
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
|
||||||
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
|
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
|
||||||
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
|
{ key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
|
||||||
|
@ -144,6 +144,13 @@ const EditChannel = () => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const submit = async () => {
|
const submit = async () => {
|
||||||
|
// some provider as AWS need both AK and SK rather than a single key,
|
||||||
|
// so we need to combine them into a single key to achieve the best compatibility.
|
||||||
|
if (inputs.ak && inputs.sk) {
|
||||||
|
console.log(`combine ak ${inputs.ak} and sk ${inputs.sk}`, inputs.ak, inputs.sk);
|
||||||
|
inputs.key = `${inputs.ak}\n${inputs.sk}`;
|
||||||
|
}
|
||||||
|
|
||||||
if (!isEdit && (inputs.name === '' || inputs.key === '')) {
|
if (!isEdit && (inputs.name === '' || inputs.key === '')) {
|
||||||
showInfo('请填写渠道名称和渠道密钥!');
|
showInfo('请填写渠道名称和渠道密钥!');
|
||||||
return;
|
return;
|
||||||
@ -392,7 +399,40 @@ const EditChannel = () => {
|
|||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
{
|
{
|
||||||
batch ? <Form.Field>
|
inputs.type === 33 && (
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='Region'
|
||||||
|
name='base_url'
|
||||||
|
required
|
||||||
|
placeholder={'region,e.g. us-west-2'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.base_url}
|
||||||
|
autoComplete=''
|
||||||
|
/>
|
||||||
|
<Form.Input
|
||||||
|
label='AK'
|
||||||
|
name='ak'
|
||||||
|
required
|
||||||
|
placeholder={'AWS IAM Access Key'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.ak}
|
||||||
|
autoComplete=''
|
||||||
|
/>
|
||||||
|
<Form.Input
|
||||||
|
label='SK'
|
||||||
|
name='sk'
|
||||||
|
required
|
||||||
|
placeholder={'AWS IAM Secret Key'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.sk}
|
||||||
|
autoComplete=''
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
inputs.type !== 33 && (batch ? <Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
label='密钥'
|
label='密钥'
|
||||||
name='key'
|
name='key'
|
||||||
@ -413,10 +453,10 @@ const EditChannel = () => {
|
|||||||
value={inputs.key}
|
value={inputs.key}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
!isEdit && (
|
inputs.type !== 33 && !isEdit && (
|
||||||
<Form.Checkbox
|
<Form.Checkbox
|
||||||
checked={batch}
|
checked={batch}
|
||||||
label='批量创建'
|
label='批量创建'
|
||||||
@ -426,7 +466,7 @@ const EditChannel = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
|
inputs.type !== 3 && inputs.type !== 33 && inputs.type !== 8 && inputs.type !== 22 && (
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='代理'
|
label='代理'
|
||||||
|
Loading…
Reference in New Issue
Block a user