mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 01:06:37 +08:00
74 lines
2.4 KiB
Go
74 lines
2.4 KiB
Go
package vertexai
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
credentials "cloud.google.com/go/iam/credentials/apiv1"
|
||
"cloud.google.com/go/iam/credentials/apiv1/credentialspb"
|
||
"github.com/patrickmn/go-cache"
|
||
"google.golang.org/api/option"
|
||
)
|
||
|
||
// ApplicationDefaultCredentials is the decoded form of the credentials JSON
// handed to getToken as adcJson. Its field names follow the Google
// service-account key file format. getToken decodes the whole document but
// only uses ClientEmail, to build the service-account resource name.
type ApplicationDefaultCredentials struct {
	// Identity of the key and of the account that owns it.
	Type         string `json:"type"`
	ProjectID    string `json:"project_id"`
	PrivateKeyID string `json:"private_key_id"`
	PrivateKey   string `json:"private_key"`
	ClientEmail  string `json:"client_email"`
	ClientID     string `json:"client_id"`

	// Auth / certificate endpoints and universe domain carried by the file.
	AuthURI                 string `json:"auth_uri"`
	TokenURI                string `json:"token_uri"`
	AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
	ClientX509CertURL       string `json:"client_x509_cert_url"`
	UniverseDomain          string `json:"universe_domain"`
}
|
||
|
||
// Cache holds the access tokens minted by getToken, one entry per channel.
// Entries stored with cache.DefaultExpiration live for 50 minutes, and
// expired entries are purged every 55 minutes.
var Cache = cache.New(
	50*time.Minute, // default expiration
	55*time.Minute, // cleanup interval
)
|
||
|
||
const (
	// defaultScope is the only OAuth2 scope requested when generating
	// access tokens.
	defaultScope = "https://www.googleapis.com/auth/cloud-platform"
)
|
||
|
||
func getToken(ctx context.Context, channelId int, adcJson string) (string, error) {
|
||
cacheKey := fmt.Sprintf("vertexai-token-%d", channelId)
|
||
if token, found := Cache.Get(cacheKey); found {
|
||
return token.(string), nil
|
||
}
|
||
adc := &ApplicationDefaultCredentials{}
|
||
if err := json.Unmarshal([]byte(adcJson), adc); err != nil {
|
||
return "", fmt.Errorf("Failed to decode credentials file: %w", err)
|
||
}
|
||
|
||
c, err := credentials.NewIamCredentialsClient(ctx, option.WithCredentialsJSON([]byte(adcJson)))
|
||
if err != nil {
|
||
return "", fmt.Errorf("Failed to create client: %w", err)
|
||
}
|
||
defer c.Close()
|
||
|
||
req := &credentialspb.GenerateAccessTokenRequest{
|
||
// See https://pkg.go.dev/cloud.google.com/go/iam/credentials/apiv1/credentialspb#GenerateAccessTokenRequest.
|
||
Name: fmt.Sprintf("projects/-/serviceAccounts/%s", adc.ClientEmail),
|
||
Scope: []string{defaultScope},
|
||
}
|
||
resp, err := c.GenerateAccessToken(ctx, req)
|
||
if err != nil {
|
||
return "", fmt.Errorf("Failed to generate access token: %w", err)
|
||
}
|
||
_ = resp
|
||
|
||
// 使用响应中的过期时间,如果获取失败则使用默认时间
|
||
expiration := cache.DefaultExpiration
|
||
if resp.ExpireTime != nil {
|
||
_expiration := time.Until(resp.ExpireTime.AsTime())
|
||
// 提前10分钟过期,避免边界情况
|
||
_expiration = _expiration - 10*time.Minute
|
||
if _expiration > 0 {
|
||
expiration = _expiration
|
||
}
|
||
}
|
||
|
||
Cache.Set(cacheKey, resp.AccessToken, expiration)
|
||
return resp.AccessToken, nil
|
||
}
|