// Mirror of https://github.com/linux-do/new-api.git
// (synced 2025-09-17 16:06:38 +08:00; 123 lines, 3.3 KiB, Go).

package vertex
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"errors"
|
|
"github.com/bytedance/gopkg/cache/asynccache"
|
|
"github.com/golang-jwt/jwt"
|
|
"net/http"
|
|
"net/url"
|
|
relaycommon "one-api/relay/common"
|
|
"strings"
|
|
|
|
"fmt"
|
|
"time"
|
|
)
|
|
|
|
// Credentials holds the fields decoded from what appears to be a Google Cloud
// service-account key JSON document (see the json tags on each field).
//
// In this file only ClientEmail and PrivateKey are read: getAccessToken passes
// them to createSignedJWT. The other fields are not used by the code shown here.
type Credentials struct {
	// ProjectID is decoded from "project_id".
	ProjectID    string `json:"project_id"`
	// PrivateKeyID is decoded from "private_key_id".
	PrivateKeyID string `json:"private_key_id"`
	// PrivateKey is the PEM key text. createSignedJWT requires a PKCS#8 RSA
	// key and strips CR/LF and literal "\n" sequences before decoding it.
	PrivateKey   string `json:"private_key"`
	// ClientEmail is used as the JWT "iss" claim by createSignedJWT.
	ClientEmail  string `json:"client_email"`
	// ClientID is decoded from "client_id".
	ClientID     string `json:"client_id"`
}
|
|
|
|
var Cache = asynccache.NewAsyncCache(asynccache.Options{
|
|
RefreshDuration: time.Minute * 35,
|
|
EnableExpire: true,
|
|
ExpireDuration: time.Minute * 30,
|
|
Fetcher: func(key string) (interface{}, error) {
|
|
return nil, errors.New("not found")
|
|
},
|
|
})
|
|
|
|
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
|
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
|
|
val, err := Cache.Get(cacheKey)
|
|
if err == nil {
|
|
return val.(string), nil
|
|
}
|
|
|
|
signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
|
}
|
|
newToken, err := exchangeJwtForAccessToken(signedJWT)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
|
|
}
|
|
if err := Cache.SetDefault(cacheKey, newToken); err {
|
|
return newToken, nil
|
|
}
|
|
return newToken, nil
|
|
}
|
|
|
|
func createSignedJWT(email, privateKeyPEM string) (string, error) {
|
|
|
|
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
|
|
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
|
|
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
|
|
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
|
|
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
|
|
|
|
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
|
|
if block == nil {
|
|
return "", fmt.Errorf("failed to parse PEM block containing the private key")
|
|
}
|
|
|
|
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
|
|
if !ok {
|
|
return "", fmt.Errorf("not an RSA private key")
|
|
}
|
|
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"iss": email,
|
|
"scope": "https://www.googleapis.com/auth/cloud-platform",
|
|
"aud": "https://www.googleapis.com/oauth2/v4/token",
|
|
"exp": now.Add(time.Minute * 35).Unix(),
|
|
"iat": now.Unix(),
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
signedToken, err := token.SignedString(rsaPrivateKey)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return signedToken, nil
|
|
}
|
|
|
|
// exchangeJwtForAccessToken trades a signed service-account JWT for an OAuth2
// access token. It POSTs the JWT-bearer grant
// (grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer, assertion=signedJWT)
// as a form to Google's oauth2/v4 token endpoint.
//
// It returns the non-empty "access_token" string from the JSON response.
// It returns an error, including the HTTP status where one was received, when:
//   - the request fails or times out,
//   - the body is not JSON (e.g. an HTML 5xx page from a proxy), or
//   - the response carries no usable access token (e.g. an OAuth error object).
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
	const authURL = "https://www.googleapis.com/oauth2/v4/token"
	// http.PostForm would use http.DefaultClient, which has no timeout; a
	// stalled token endpoint would then block the calling request forever.
	// A nil Transport still means http.DefaultTransport, so connections are
	// pooled as before.
	const exchangeTimeout = 30 * time.Second

	data := url.Values{}
	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
	data.Set("assertion", signedJWT)

	client := &http.Client{Timeout: exchangeTimeout}
	resp, err := client.PostForm(authURL, data)
	if err != nil {
		return "", fmt.Errorf("posting JWT to token endpoint: %w", err)
	}
	defer resp.Body.Close()

	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("decoding token response (HTTP %d): %w", resp.StatusCode, err)
	}

	if accessToken, ok := result["access_token"].(string); ok && accessToken != "" {
		return accessToken, nil
	}

	// Typically an OAuth error object such as {"error":"invalid_grant",...};
	// include it verbatim so the cause is visible to operators.
	return "", fmt.Errorf("failed to get access token (HTTP %d): %v", resp.StatusCode, result)
}
|