one-api/relay/adaptor/vertex/token.go
2024-04-25 13:41:27 +08:00

165 lines
3.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package vertex
import (
"bytes"
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/songquanpeng/one-api/relay/meta"
"io"
"net/http"
"time"
)
// Credentials holds the subset of a Google service-account key needed to sign
// JWT assertions. The JSON tags use the same field names as the decoded
// service-account key JSON, so the key can be unmarshalled directly into this
// type.
type Credentials struct {
	// PrivateKey is the PEM-encoded private key used to sign the JWT.
	PrivateKey string `json:"private_key"`
	// PrivateKeyID identifies the signing key; it is sent as the JWT "kid" header.
	PrivateKeyID string `json:"private_key_id"`
	// ClientEmail is the service account's email, used as the JWT issuer and subject.
	ClientEmail string `json:"client_email"`
}
// ServiceAccount bundles service-account key material with the OAuth2 scope
// string associated with it; it is the receiver for JWT creation and the
// access-token exchange.
type ServiceAccount struct {
	// Cred is the key material used to sign JWT assertions; it must be non-nil
	// before a JWT can be created.
	Cred *Credentials

	// Scopes is the OAuth2 scope string recorded for this account (getToken
	// sets it to the package-level default).
	Scopes string
}
// scopes is the default OAuth2 scope requested for Vertex AI access tokens.
var (
	scopes = "https://www.googleapis.com/auth/cloud-platform"
)
// createSignedJWT builds an RS256-signed JWT assertion from the service
// account credentials, for use with the OAuth2 JWT-bearer grant.
//
// The token is issued now and expires one hour later. It requests sa.Scopes,
// or the package default scopes when sa.Scopes is empty. The "kid" header
// carries the credentials' PrivateKeyID. The private key must be a PEM-encoded
// RSA key in PKCS#8 or PKCS#1 form.
//
// It returns an error if the credentials are nil, the key cannot be decoded or
// parsed, the key is not RSA, or signing fails.
func (sa *ServiceAccount) createSignedJWT() (string, error) {
	if sa.Cred == nil {
		return "", fmt.Errorf("credentials are nil")
	}

	// Honour the per-account scope; fall back to the package default so a
	// zero-valued Scopes field still yields a usable token.
	scope := sa.Scopes
	if scope == "" {
		scope = scopes
	}

	issuedAt := time.Now()
	claims := jwt.MapClaims{
		"iss":   sa.Cred.ClientEmail,
		"sub":   sa.Cred.ClientEmail,
		"aud":   "https://www.googleapis.com/oauth2/v4/token",
		"iat":   issuedAt.Unix(),
		"exp":   issuedAt.Add(time.Hour).Unix(),
		"scope": scope,
	}

	// NewWithClaims already populates the "alg" and "typ" headers; only the
	// key ID has to be added.
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	token.Header["kid"] = sa.Cred.PrivateKeyID

	rsaPrivateKey, err := parseServiceAccountPrivateKey(sa.Cred.PrivateKey)
	if err != nil {
		return "", err
	}

	signedToken, err := token.SignedString(rsaPrivateKey)
	if err != nil {
		return "", fmt.Errorf("signing JWT: %w", err)
	}
	return signedToken, nil
}

// parseServiceAccountPrivateKey decodes the first PEM block in pemKey and
// parses it as an RSA private key, accepting PKCS#8 and PKCS#1 encodings.
func parseServiceAccountPrivateKey(pemKey string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode([]byte(pemKey))
	if block == nil {
		return nil, errors.New("failed to decode PEM block containing private key")
	}

	key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		// Fall back to PKCS#1 ("BEGIN RSA PRIVATE KEY"). If that fails too,
		// report the PKCS#8 error, since PKCS#8 is the primary expected format.
		if rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
			return rsaKey, nil
		}
		return nil, fmt.Errorf("parsing private key: %w", err)
	}

	rsaKey, ok := key.(*rsa.PrivateKey)
	if !ok {
		return nil, errors.New("private key is not of type RSA")
	}
	return rsaKey, nil
}
// getToken returns an OAuth2 access token for the service account. It first
// signs a JWT assertion with the account's credentials, then trades that
// assertion for an access token via exchangeJwtForAccessToken. Errors from
// either step are returned unchanged.
func (sa *ServiceAccount) getToken(ctx context.Context) (string, error) {
	assertion, signErr := sa.createSignedJWT()
	if signErr != nil {
		return "", signErr
	}
	accessToken, exchangeErr := exchangeJwtForAccessToken(ctx, assertion)
	return accessToken, exchangeErr
}
// exchangeJwtForAccessToken exchanges a signed JWT assertion for a Google
// OAuth2 access token. It uses the JWT-bearer grant against Google's token
// endpoint.
//
// It returns an error in any of these cases:
//   - the request cannot be built or sent;
//   - the endpoint answers with a non-2xx status (the OAuth "error" and
//     "error_description" fields are included when present);
//   - the body is not valid JSON;
//   - the response carries no access_token.
//
// Callers never receive an empty token together with a nil error.
func exchangeJwtForAccessToken(ctx context.Context, signedJWT string) (string, error) {
	return exchangeJwtForAccessTokenAt(ctx, "https://www.googleapis.com/oauth2/v4/token", signedJWT)
}

// exchangeJwtForAccessTokenAt performs the JWT-bearer token exchange against
// tokenURL. It is split from exchangeJwtForAccessToken so the endpoint can be
// substituted, e.g. with a local test server.
func exchangeJwtForAccessTokenAt(ctx context.Context, tokenURL, signedJWT string) (string, error) {
	jsonData, err := json.Marshal(map[string]string{
		"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
		"assertion":  signedJWT,
	})
	if err != nil {
		return "", fmt.Errorf("encoding token request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, bytes.NewReader(jsonData))
	if err != nil {
		return "", fmt.Errorf("building token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	// Bound the exchange even when ctx carries no deadline.
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("requesting access token: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading token response: %w", err)
	}

	var data struct {
		AccessToken      string `json:"access_token"`
		Error            string `json:"error"`
		ErrorDescription string `json:"error_description"`
	}
	// A failed response may not be JSON at all (e.g. a proxy error page), so
	// the decode error only matters when the status reported success.
	jsonErr := json.Unmarshal(body, &data)

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		switch {
		case data.Error != "" && data.ErrorDescription != "":
			return "", fmt.Errorf("token endpoint returned %s: %s: %s", resp.Status, data.Error, data.ErrorDescription)
		case data.Error != "":
			return "", fmt.Errorf("token endpoint returned %s: %s", resp.Status, data.Error)
		default:
			return "", fmt.Errorf("token endpoint returned %s", resp.Status)
		}
	}
	if jsonErr != nil {
		return "", fmt.Errorf("decoding token response: %w", jsonErr)
	}
	if data.AccessToken == "" {
		return "", errors.New("token response contains no access_token")
	}
	return data.AccessToken, nil
}
// getToken obtains a Google OAuth2 access token for a Vertex AI request.
//
// It decodes a base64-encoded service-account key (JSON), builds a
// ServiceAccount with the package default scopes, and performs the JWT-bearer
// exchange. The meta argument is currently unused.
//
// NOTE(review): the encoded key is still hard-coded to an empty string, so this
// always returns the "not configured" error until the key is wired in from the
// channel configuration (presumably via meta — confirm).
// TODO: a new token is exchanged on every call; cache it until shortly before
// it expires.
func getToken(ctx context.Context, meta *meta.Meta) (string, error) {
	encodedString := ""
	if encodedString == "" {
		// Fail clearly instead of surfacing an opaque JSON decode error.
		return "", errors.New("service account credentials are not configured")
	}

	decodedBytes, err := base64.StdEncoding.DecodeString(encodedString)
	if err != nil {
		return "", fmt.Errorf("decoding base64 service account key: %w", err)
	}

	// Decode only the fields we need, so unrelated non-string values in the
	// key file cannot make parsing fail.
	var key struct {
		PrivateKey   string `json:"private_key"`
		PrivateKeyID string `json:"private_key_id"`
		ClientEmail  string `json:"client_email"`
	}
	if err := json.Unmarshal(decodedBytes, &key); err != nil {
		return "", fmt.Errorf("parsing service account key JSON: %w", err)
	}
	if key.PrivateKey == "" {
		return "", errors.New(`service account key is missing "private_key"`)
	}
	if key.ClientEmail == "" {
		return "", errors.New(`service account key is missing "client_email"`)
	}

	sa := &ServiceAccount{
		Cred: &Credentials{
			PrivateKey:   key.PrivateKey,
			PrivateKeyID: key.PrivateKeyID,
			ClientEmail:  key.ClientEmail,
		},
		Scopes: scopes,
	}
	return sa.getToken(ctx)
}