mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-30 07:06:38 +08:00
165 lines
3.7 KiB
Go
165 lines
3.7 KiB
Go
package vertex
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/rsa"
|
||
"crypto/x509"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"encoding/pem"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/golang-jwt/jwt"
|
||
"github.com/songquanpeng/one-api/relay/meta"
|
||
"io"
|
||
"net/http"
|
||
"time"
|
||
)
|
||
|
||
// Credentials carries the service-account key material needed to sign a JWT
// assertion: PrivateKey is the PEM-encoded private key used for signing,
// PrivateKeyID is sent as the JWT "kid" header, and ClientEmail is used as the
// JWT issuer and subject.
type Credentials struct {
	PrivateKey, PrivateKeyID, ClientEmail string
}
|
||
|
||
// ServiceAccount holds the credentials and scopes required for token generation
|
||
type ServiceAccount struct {
|
||
Cred *Credentials
|
||
Scopes string
|
||
}
|
||
|
||
var (
	// scopes is the default OAuth scope (Google Cloud "cloud-platform") placed
	// in the JWT assertion when requesting an access token.
	scopes = "https://www.googleapis.com/auth/cloud-platform"
)
|
||
|
||
// createSignedJWT creates a Signed JWT from service account credentials
|
||
func (sa *ServiceAccount) createSignedJWT() (string, error) {
|
||
if sa.Cred == nil {
|
||
return "", fmt.Errorf("credentials are nil")
|
||
}
|
||
|
||
issuedAt := time.Now()
|
||
expiresAt := issuedAt.Add(time.Hour)
|
||
|
||
claims := &jwt.MapClaims{
|
||
"iss": sa.Cred.ClientEmail,
|
||
"sub": sa.Cred.ClientEmail,
|
||
"aud": "https://www.googleapis.com/oauth2/v4/token",
|
||
"iat": issuedAt.Unix(),
|
||
"exp": expiresAt.Unix(),
|
||
"scope": scopes,
|
||
}
|
||
|
||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||
token.Header["kid"] = sa.Cred.PrivateKeyID
|
||
token.Header["alg"] = "RS256"
|
||
token.Header["typ"] = "JWT"
|
||
|
||
// 解析 PEM 编码的私钥
|
||
block, _ := pem.Decode([]byte(sa.Cred.PrivateKey))
|
||
if block == nil {
|
||
return "", errors.New("failed to decode PEM block containing private key")
|
||
}
|
||
|
||
// 解析 RSA 私钥
|
||
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
|
||
if !ok {
|
||
return "", errors.New("private key is not of type RSA")
|
||
}
|
||
|
||
signedToken, err := token.SignedString(rsaPrivateKey)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return signedToken, nil
|
||
}
|
||
|
||
// getToken uses the signed JWT to obtain an access token
|
||
func (sa *ServiceAccount) getToken(ctx context.Context) (string, error) {
|
||
signedJWT, err := sa.createSignedJWT()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return exchangeJwtForAccessToken(ctx, signedJWT)
|
||
}
|
||
|
||
// vertexTokenClient is shared by every token exchange so connections to the
// OAuth endpoint can be reused; its timeout bounds each exchange.
var vertexTokenClient = &http.Client{
	Timeout: 5 * time.Second,
}

// exchangeJwtForAccessToken exchanges a signed JWT assertion for a Google OAuth
// access token using the JWT-bearer grant.
//
// It returns an error when:
//   - the request cannot be built or sent;
//   - the endpoint answers with a non-2xx status (the OAuth "error" and
//     "error_description" fields are included when present);
//   - the response body is not valid JSON;
//   - the response contains no access_token.
//
// It never returns an empty token together with a nil error.
func exchangeJwtForAccessToken(ctx context.Context, signedJWT string) (string, error) {
	const authURL = "https://www.googleapis.com/oauth2/v4/token"

	payload, err := json.Marshal(map[string]string{
		"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
		"assertion":  signedJWT,
	})
	if err != nil {
		return "", fmt.Errorf("encoding token request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, authURL, bytes.NewReader(payload))
	if err != nil {
		return "", fmt.Errorf("building token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := vertexTokenClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("requesting access token: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Token responses are small; cap the read so a misbehaving endpoint cannot
	// make us buffer an arbitrarily large body.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("reading token response: %w", err)
	}

	var tr struct {
		AccessToken      string `json:"access_token"`
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
	// Decode before checking the status: OAuth error replies typically carry
	// JSON "error" fields that make the failure actionable.
	decodeErr := json.Unmarshal(body, &tr)

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		if decodeErr == nil && tr.Error != "" {
			msg := tr.Error
			if tr.ErrorDescription != "" {
				msg += ": " + tr.ErrorDescription
			}
			return "", fmt.Errorf("token endpoint returned HTTP %d: %s", resp.StatusCode, msg)
		}
		return "", fmt.Errorf("token endpoint returned HTTP %d", resp.StatusCode)
	}
	if decodeErr != nil {
		return "", fmt.Errorf("decoding token response: %w", decodeErr)
	}
	if tr.AccessToken == "" {
		return "", errors.New("token response contains no access_token")
	}
	return tr.AccessToken, nil
}
|
||
|
||
func getToken(ctx context.Context, meta *meta.Meta) (string, error) {
|
||
// todo 每次请求都要换次token???
|
||
encodedString := ""
|
||
decodedBytes, err := base64.StdEncoding.DecodeString(encodedString)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
m := make(map[string]string)
|
||
err = json.Unmarshal(decodedBytes, &m)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
sa := &ServiceAccount{
|
||
Cred: &Credentials{
|
||
PrivateKey: m["private_key"],
|
||
PrivateKeyID: m["private_key_id"],
|
||
ClientEmail: m["client_email"],
|
||
},
|
||
Scopes: scopes,
|
||
}
|
||
return sa.getToken(ctx)
|
||
}
|