mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-12 03:13:41 +08:00
🎨 调整供应商目录结构,合并文本输出函数
This commit is contained in:
124
providers/baidu/base.go
Normal file
124
providers/baidu/base.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/providers/base"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
type BaiduProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
|
||||
return &BaiduProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: "https://aip.baidubce.com",
|
||||
ChatCompletions: "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat",
|
||||
Embeddings: "/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings",
|
||||
Context: c,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
var modelNameMap = map[string]string{
|
||||
"ERNIE-Bot": "completions",
|
||||
"ERNIE-Bot-turbo": "eb-instant",
|
||||
"ERNIE-Bot-4": "completions_pro",
|
||||
"BLOOMZ-7B": "bloomz_7b1",
|
||||
"Embedding-V1": "embedding-v1",
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
apiKey, err := p.getBaiduAccessToken()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s/%s?access_token=%s", baseURL, requestURL, modelNameMap[modelName], apiKey)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getBaiduAccessToken() (string, error) {
|
||||
apiKey := p.Context.GetString("api_key")
|
||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||
var accessToken BaiduAccessToken
|
||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||
// soon this will expire
|
||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
||||
go func() {
|
||||
_, _ = p.getBaiduAccessTokenHelper(apiKey)
|
||||
}()
|
||||
}
|
||||
return accessToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
accessToken, err := p.getBaiduAccessTokenHelper(apiKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if accessToken == nil {
|
||||
return "", errors.New("getBaiduAccessToken return a nil token")
|
||||
}
|
||||
return (*accessToken).AccessToken, nil
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid baidu apikey")
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
url := fmt.Sprintf(p.BaseURL+"/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", parts[0], parts[1])
|
||||
|
||||
var headers = map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
req, err := client.NewRequest("POST", url, common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := common.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
var accessToken BaiduAccessToken
|
||||
err = json.NewDecoder(resp.Body).Decode(&accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accessToken.Error != "" {
|
||||
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
|
||||
}
|
||||
if accessToken.AccessToken == "" {
|
||||
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
|
||||
}
|
||||
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
|
||||
baiduTokenStore.Store(apiKey, accessToken)
|
||||
return &accessToken, nil
|
||||
}
|
||||
Reference in New Issue
Block a user