mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-10 16:13:42 +08:00
feat: support vertex ai #377
This commit is contained in:
@@ -79,9 +79,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -346,7 +346,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
var usage *dto.Usage
|
||||
usage = &dto.Usage{}
|
||||
@@ -428,7 +428,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -454,15 +454,15 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
|
||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if requestMode == RequestModeCompletion {
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens = completionTokens
|
||||
usage.TotalTokens = promptTokens + completionTokens
|
||||
usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
} else {
|
||||
usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
|
||||
@@ -70,9 +70,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = geminiChatStreamHandler(c, resp, info)
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
err, usage = GeminiChatHandler(c, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseText := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createAt := common.GetTimestamp()
|
||||
@@ -279,7 +279,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
183
relay/channel/vertex/adaptor.go
Normal file
183
relay/channel/vertex/adaptor.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/copier"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
RequestModeClaude = 1
|
||||
RequestModeGemini = 2
|
||||
RequestModeLlama = 3
|
||||
)
|
||||
|
||||
var claudeModelMap = map[string]string{
|
||||
"claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
|
||||
"claude-3-opus-20240229": "claude-3-opus@20240229",
|
||||
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
|
||||
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
|
||||
}
|
||||
|
||||
const anthropicVersion = "vertex-2023-10-16"
|
||||
|
||||
type Adaptor struct {
|
||||
RequestMode int
|
||||
AccountCredentials Credentials
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude") {
|
||||
a.RequestMode = RequestModeClaude
|
||||
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
|
||||
a.RequestMode = RequestModeGemini
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") {
|
||||
a.RequestMode = RequestModeLlama
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &Credentials{}
|
||||
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
a.AccountCredentials = *adc
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
if info.IsStream {
|
||||
suffix = "streamGenerateContent?alt=sse"
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
info.ApiVersion,
|
||||
adc.ProjectID,
|
||||
info.ApiVersion,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if info.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
} else {
|
||||
suffix = "rawPredict"
|
||||
}
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
info.UpstreamModelName = v
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
info.ApiVersion,
|
||||
adc.ProjectID,
|
||||
info.ApiVersion,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
info.ApiVersion,
|
||||
adc.ProjectID,
|
||||
info.ApiVersion,
|
||||
), nil
|
||||
}
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if a.RequestMode == RequestModeClaude {
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vertexClaudeReq := &VertexAIClaudeRequest{
|
||||
AnthropicVersion: anthropicVersion,
|
||||
}
|
||||
if err = copier.Copy(vertexClaudeReq, claudeReq); err != nil {
|
||||
return nil, errors.New("failed to copy claude request")
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest := gemini.CovertGemini2OpenAI(*request)
|
||||
c.Set("request_model", request.Model)
|
||||
return geminiRequest, nil
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return request, nil
|
||||
}
|
||||
return nil, errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
}
|
||||
} else {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatHandler(c, resp)
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
15
relay/channel/vertex/constants.go
Normal file
15
relay/channel/vertex/constants.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package vertex
|
||||
|
||||
var ModelList = []string{
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
|
||||
//"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
|
||||
"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
|
||||
|
||||
"meta/llama3-405b-instruct-maas",
|
||||
}
|
||||
|
||||
var ChannelName = "vertex-ai"
|
||||
17
relay/channel/vertex/dto.go
Normal file
17
relay/channel/vertex/dto.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package vertex
|
||||
|
||||
import "one-api/relay/channel/claude"
|
||||
|
||||
type VertexAIClaudeRequest struct {
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
Messages []claude.ClaudeMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Tools []claude.Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
}
|
||||
122
relay/channel/vertex/service_account.go
Normal file
122
relay/channel/vertex/service_account.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Credentials struct {
|
||||
ProjectID string `json:"project_id"`
|
||||
PrivateKeyID string `json:"private_key_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ClientEmail string `json:"client_email"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
var Cache = asynccache.NewAsyncCache(asynccache.Options{
|
||||
RefreshDuration: time.Minute * 35,
|
||||
EnableExpire: true,
|
||||
ExpireDuration: time.Minute * 30,
|
||||
Fetcher: func(key string) (interface{}, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
})
|
||||
|
||||
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
||||
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
|
||||
val, err := Cache.Get(cacheKey)
|
||||
if err == nil {
|
||||
return val.(string), nil
|
||||
}
|
||||
|
||||
signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||
}
|
||||
newToken, err := exchangeJwtForAccessToken(signedJWT)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
|
||||
}
|
||||
if err := Cache.SetDefault(cacheKey, newToken); err {
|
||||
return newToken, nil
|
||||
}
|
||||
return newToken, nil
|
||||
}
|
||||
|
||||
func createSignedJWT(email, privateKeyPEM string) (string, error) {
|
||||
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
|
||||
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
|
||||
|
||||
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("failed to parse PEM block containing the private key")
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("not an RSA private key")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": email,
|
||||
"scope": "https://www.googleapis.com/auth/cloud-platform",
|
||||
"aud": "https://www.googleapis.com/oauth2/v4/token",
|
||||
"exp": now.Add(time.Minute * 30).Unix(),
|
||||
"iat": now.Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
signedToken, err := token.SignedString(rsaPrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
|
||||
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
resp, err := http.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if accessToken, ok := result["access_token"].(string); ok {
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
Reference in New Issue
Block a user