feat: support vertex ai #377

This commit is contained in:
CalciumIon
2024-08-27 20:19:51 +08:00
parent 46e03683ce
commit ac4262c542
18 changed files with 609 additions and 47 deletions

View File

@@ -0,0 +1,183 @@
package vertex
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"strings"
)
// Request modes identify which upstream API family a request is shaped for.
// Adaptor.Init picks one of them from the upstream model name.
const (
	RequestModeClaude = iota + 1 // 1
	RequestModeGemini            // 2
	RequestModeLlama             // 3
)
// claudeModelMap translates the dash-dated Claude model names accepted by this
// adaptor into the "@"-dated names that GetRequestURL places in the request
// path. Names missing from the map are used unchanged.
var claudeModelMap = map[string]string{
	"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
	"claude-3-haiku-20240307":    "claude-3-haiku@20240307",
	"claude-3-opus-20240229":     "claude-3-opus@20240229",
	"claude-3-sonnet-20240229":   "claude-3-sonnet@20240229",
}
// anthropicVersion is the value ConvertRequest places in the
// "anthropic_version" field of every Claude request body it builds.
const anthropicVersion = "vertex-2023-10-16"
// Adaptor is the Vertex AI channel adaptor. It dispatches on RequestMode to
// the Claude, Gemini or Llama request/response handling.
type Adaptor struct {
// RequestMode is one of the RequestMode* constants; Init sets it from the
// upstream model name.
RequestMode int
// AccountCredentials holds the credentials GetRequestURL decodes from the
// channel's API key; getAccessToken reads them to sign the token request.
AccountCredentials Credentials
}
// ConvertAudioRequest is a placeholder: audio requests are not handled by this
// adaptor, so it always returns a nil reader and a "not implemented" error.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	var body io.Reader // stays nil; nothing is converted
	return body, errors.New("not implemented")
}
// ConvertImageRequest is a placeholder: image requests are not handled by this
// adaptor, so it always returns a nil payload and a "not implemented" error.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
	var converted any // stays nil; nothing is converted
	return converted, errors.New("not implemented")
}
// Init selects the request mode from the upstream model name: a "claude"
// prefix selects Claude, a "gemini" prefix selects Gemini, and any name
// containing "llama" selects Llama. For any other name RequestMode is left
// untouched.
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
	model := info.UpstreamModelName
	switch {
	case strings.HasPrefix(model, "claude"):
		a.RequestMode = RequestModeClaude
	case strings.HasPrefix(model, "gemini"):
		a.RequestMode = RequestModeGemini
	case strings.Contains(model, "llama"):
		a.RequestMode = RequestModeLlama
	}
}
// GetRequestURL decodes the channel's API key as JSON credentials, stores them
// on the adaptor, and builds the upstream endpoint URL for the current request
// mode.
//
// info.ApiVersion is used both as the host prefix and as the "locations"
// path segment. In Claude mode, info.UpstreamModelName is rewritten in place
// when claudeModelMap has an entry for it.
//
// It returns an error when the API key is not valid JSON or when the request
// mode is not one of the RequestMode* constants.
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
	var creds Credentials
	if err := json.Unmarshal([]byte(info.ApiKey), &creds); err != nil {
		return "", fmt.Errorf("failed to decode credentials file: %w", err)
	}
	// Kept on the adaptor so the token lookup can sign with these credentials.
	a.AccountCredentials = creds

	region := info.ApiVersion

	switch a.RequestMode {
	case RequestModeGemini:
		action := "generateContent"
		if info.IsStream {
			action = "streamGenerateContent?alt=sse"
		}
		return fmt.Sprintf(
			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
			region,
			creds.ProjectID,
			region,
			info.UpstreamModelName,
			action,
		), nil

	case RequestModeClaude:
		action := "rawPredict"
		if info.IsStream {
			action = "streamRawPredict?alt=sse"
		}
		if mapped, found := claudeModelMap[info.UpstreamModelName]; found {
			info.UpstreamModelName = mapped
		}
		return fmt.Sprintf(
			"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
			region,
			creds.ProjectID,
			region,
			info.UpstreamModelName,
			action,
		), nil

	case RequestModeLlama:
		// Same endpoint for streaming and non-streaming requests.
		return fmt.Sprintf(
			"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
			region,
			creds.ProjectID,
			region,
		), nil

	default:
		return "", errors.New("unsupported request mode")
	}
}
// SetupRequestHeader applies the shared API request headers and then sets the
// Authorization header to a bearer access token obtained via getAccessToken.
//
// getAccessToken reads a.AccountCredentials, which GetRequestURL populates;
// presumably GetRequestURL has run first for this request — confirm with the
// caller.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, req)

	token, tokenErr := getAccessToken(a, info)
	if tokenErr != nil {
		return tokenErr
	}

	bearer := "Bearer " + token
	req.Header.Set("Authorization", bearer)
	return nil
}
// ConvertRequest turns an OpenAI-style chat request into the payload expected
// by the upstream selected by a.RequestMode.
//
//   - Claude: converts to a Claude message request, copies it into a
//     VertexAIClaudeRequest carrying anthropicVersion, and records the model in
//     the gin context under "request_model".
//   - Gemini: converts with the gemini package and records the model in the
//     gin context under "request_model".
//   - Llama: passes the request through unchanged.
//
// It returns an error when request is nil, when conversion or copying fails
// (the underlying cause is wrapped), or when the request mode is unsupported.
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request == nil {
		return nil, errors.New("request is nil")
	}

	switch a.RequestMode {
	case RequestModeClaude:
		claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
		if err != nil {
			return nil, err
		}
		vertexClaudeReq := &VertexAIClaudeRequest{
			AnthropicVersion: anthropicVersion,
		}
		if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
			// Keep the cause so a failed copy can actually be diagnosed.
			return nil, fmt.Errorf("failed to copy claude request: %w", err)
		}
		c.Set("request_model", request.Model)
		return vertexClaudeReq, nil

	case RequestModeGemini:
		geminiRequest := gemini.CovertGemini2OpenAI(*request)
		c.Set("request_model", request.Model)
		return geminiRequest, nil

	case RequestModeLlama:
		return request, nil
	}

	return nil, errors.New("unsupported request mode")
}
// ConvertRerankRequest is a placeholder: rerank requests are not handled by
// this adaptor. It returns a "not implemented" error, matching
// ConvertAudioRequest and ConvertImageRequest, rather than a nil payload with a
// nil error that a caller could mistake for a successful conversion.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
	return nil, errors.New("not implemented")
}
// DoRequest sends the prepared request body upstream by delegating to the
// shared channel.DoApiRequest helper, passing this adaptor along.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	resp, err := channel.DoApiRequest(a, c, info, requestBody)
	return resp, err
}
// DoResponse hands the upstream response to the handler matching the request
// mode and whether the request is streaming, and returns that handler's usage
// and error.
//
// When a.RequestMode is none of the RequestMode* constants no handler runs and
// both results are nil.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
	switch a.RequestMode {
	case RequestModeClaude:
		if info.IsStream {
			err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
		} else {
			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
		}
	case RequestModeGemini:
		if info.IsStream {
			err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
		} else {
			err, usage = gemini.GeminiChatHandler(c, resp)
		}
	case RequestModeLlama:
		if info.IsStream {
			err, usage = openai.OaiStreamHandler(c, resp, info)
		} else {
			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
		}
	}
	return usage, err
}
// GetModelList returns the package-level ModelList slice itself (not a copy).
func (a *Adaptor) GetModelList() []string {
	models := ModelList
	return models
}
// GetChannelName returns the package-level ChannelName for this adaptor.
func (a *Adaptor) GetChannelName() string {
	name := ChannelName
	return name
}

View File

@@ -0,0 +1,15 @@
package vertex
// ModelList enumerates the model names this channel advertises: the Claude 3
// family, several Gemini models and one Llama model.
var ModelList = []string{
	// Claude
	"claude-3-sonnet-20240229",
	"claude-3-opus-20240229",
	"claude-3-haiku-20240307",
	"claude-3-5-sonnet-20240620",
	// Gemini
	//"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
	"gemini-1.5-pro-001",
	"gemini-1.5-flash-001",
	"gemini-pro",
	"gemini-pro-vision",
	// Llama
	"meta/llama3-405b-instruct-maas",
}
// ChannelName is the identifier Adaptor.GetChannelName reports for this channel.
var ChannelName = "vertex-ai"

View File

@@ -0,0 +1,17 @@
package vertex
import "one-api/relay/channel/claude"
// VertexAIClaudeRequest is the JSON request body sent for Claude models.
// Adaptor.ConvertRequest sets AnthropicVersion itself and fills the remaining
// fields by copying from the converted Claude request with copier.Copy.
// Every field except AnthropicVersion and Messages is tagged omitempty, so
// zero values (including a Temperature or TopP of 0) are left out of the JSON.
type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []claude.ClaudeMessage `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []claude.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}

View File

@@ -0,0 +1,122 @@
package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"github.com/bytedance/gopkg/cache/asynccache"
"github.com/golang-jwt/jwt"
"net/http"
"net/url"
relaycommon "one-api/relay/common"
"strings"
"fmt"
"time"
)
// Credentials is the subset of the JSON credentials document that the adaptor
// decodes from a channel's API key.
// NOTE(review): the field names match a Google service-account key file —
// confirm that is the expected input format.
type Credentials struct {
// ProjectID is interpolated into the "projects/..." segment of request URLs.
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
// PrivateKey is the PEM text createSignedJWT parses to sign the token request.
PrivateKey string `json:"private_key"`
// ClientEmail is used as the "iss" claim of the signed JWT.
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
}
// Cache stores access tokens; getAccessToken keys them as
// "access-token-<channel id>" and fills the cache itself via SetDefault.
// NOTE(review): RefreshDuration (35m) is longer than ExpireDuration (30m) —
// confirm this ordering is intended.
var Cache = asynccache.NewAsyncCache(asynccache.Options{
RefreshDuration: time.Minute * 35,
EnableExpire: true,
ExpireDuration: time.Minute * 30,
// The fetcher never produces a value: it always reports "not found", so
// tokens only enter the cache when getAccessToken stores them explicitly.
Fetcher: func(key string) (interface{}, error) {
return nil, errors.New("not found")
},
})
// getAccessToken returns a bearer access token for the channel identified by
// info.ChannelId.
//
// A token cached under "access-token-<channel id>" is returned when present.
// Otherwise a JWT is signed with a.AccountCredentials (client email and
// private key), exchanged for an access token, and the result is offered to
// the cache before being returned.
//
// It returns an error when the JWT cannot be created or the exchange fails.
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
	cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
	if val, err := Cache.Get(cacheKey); err == nil {
		// Guard the assertion: a non-string entry must not panic the request;
		// it is treated as a cache miss instead.
		if token, ok := val.(string); ok {
			return token, nil
		}
	}

	signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
	if err != nil {
		return "", fmt.Errorf("failed to create signed JWT: %w", err)
	}
	newToken, err := exchangeJwtForAccessToken(signedJWT)
	if err != nil {
		return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
	}

	// SetDefault reports a boolean, not an error; caching is best-effort and
	// the fresh token is returned regardless of the outcome.
	_ = Cache.SetDefault(cacheKey, newToken)
	return newToken, nil
}
func createSignedJWT(email, privateKeyPEM string) (string, error) {
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
if block == nil {
return "", fmt.Errorf("failed to parse PEM block containing the private key")
}
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("not an RSA private key")
}
now := time.Now()
claims := jwt.MapClaims{
"iss": email,
"scope": "https://www.googleapis.com/auth/cloud-platform",
"aud": "https://www.googleapis.com/oauth2/v4/token",
"exp": now.Add(time.Minute * 30).Unix(),
"iat": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedToken, err := token.SignedString(rsaPrivateKey)
if err != nil {
return "", err
}
return signedToken, nil
}
// tokenExchangeClient is the HTTP client used for the token exchange. Unlike
// the default client it has a timeout, so a stalled token endpoint cannot
// block a relay request indefinitely.
var tokenExchangeClient = &http.Client{Timeout: 30 * time.Second}

// exchangeJwtForAccessToken posts signedJWT to the OAuth token endpoint using
// the JWT-bearer grant type and returns the "access_token" string from the
// JSON response.
//
// It returns an error when the request fails, the response is not valid JSON,
// the endpoint answers with a non-200 status, or the response carries no
// string "access_token" field.
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
	const authURL = "https://www.googleapis.com/oauth2/v4/token"

	data := url.Values{}
	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
	data.Set("assertion", signedJWT)

	resp, err := tokenExchangeClient.PostForm(authURL, data)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("failed to decode token response (status %d): %w", resp.StatusCode, err)
	}

	// Report failures by status instead of relying on the field being absent.
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to get access token: status %d: %v", resp.StatusCode, result)
	}

	if accessToken, ok := result["access_token"].(string); ok {
		return accessToken, nil
	}
	return "", fmt.Errorf("failed to get access token: %v", result)
}