mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-26 20:06:37 +08:00
255 lines
7.3 KiB
Go
255 lines
7.3 KiB
Go
package vertex_claude
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/pkg/errors"
|
||
"golang.org/x/oauth2"
|
||
"golang.org/x/oauth2/google"
|
||
"io"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/dto"
|
||
relaymodel "one-api/dto"
|
||
"one-api/relay/channel/claude"
|
||
"one-api/service"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
var accessTokenMap sync.Map
|
||
|
||
func getAccessToken(json string) (string, error) {
|
||
data, ok := accessTokenMap.Load(json)
|
||
if ok {
|
||
token := data.(oauth2.Token)
|
||
if time.Now().Before(token.Expiry) {
|
||
return token.AccessToken, nil
|
||
}
|
||
}
|
||
creds, err := google.CredentialsFromJSON(context.Background(), []byte(json), "https://www.googleapis.com/auth/cloud-platform")
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
token, err := creds.TokenSource.Token()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
accessTokenMap.Store(json, *token)
|
||
return token.AccessToken, nil
|
||
}
|
||
|
||
func getRedirectModel(requestModel string) (string, error) {
|
||
if model, ok := modelIdMap[requestModel]; ok {
|
||
return model, nil
|
||
}
|
||
return "", errors.Errorf("model %s not found", requestModel)
|
||
}
|
||
|
||
func requestOpenAI2VertexClaude(request dto.GeneralOpenAIRequest) (*VertexClaudeRequest, error) {
|
||
vertexClaudeRequest := VertexClaudeRequest{
|
||
AnthropicVersion: "vertex-2023-10-16",
|
||
Stream: request.Stream,
|
||
}
|
||
if vertexClaudeRequest.MaxTokens == 0 {
|
||
vertexClaudeRequest.MaxTokens = 4096
|
||
}
|
||
formatMessages := make([]dto.Message, 0)
|
||
var lastMessage *dto.Message
|
||
for i, message := range request.Messages {
|
||
if message.Role == "" {
|
||
request.Messages[i].Role = "user"
|
||
}
|
||
fmtMessage := dto.Message{
|
||
Role: message.Role,
|
||
Content: message.Content,
|
||
}
|
||
if lastMessage != nil && lastMessage.Role == message.Role {
|
||
if lastMessage.IsStringContent() && message.IsStringContent() {
|
||
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
|
||
fmtMessage.Content = content
|
||
// delete last message
|
||
formatMessages = formatMessages[:len(formatMessages)-1]
|
||
}
|
||
}
|
||
if fmtMessage.Content == nil {
|
||
content, _ := json.Marshal("...")
|
||
fmtMessage.Content = content
|
||
}
|
||
formatMessages = append(formatMessages, fmtMessage)
|
||
lastMessage = &request.Messages[i]
|
||
}
|
||
claudeMessages := make([]claude.ClaudeMessage, 0)
|
||
for _, message := range formatMessages {
|
||
if message.Role == "system" {
|
||
if message.IsStringContent() {
|
||
vertexClaudeRequest.System = message.StringContent()
|
||
} else {
|
||
contents := message.ParseContent()
|
||
content := ""
|
||
for _, ctx := range contents {
|
||
if ctx.Type == "text" {
|
||
content += ctx.Text
|
||
}
|
||
}
|
||
vertexClaudeRequest.System = content
|
||
}
|
||
} else {
|
||
claudeMessage := claude.ClaudeMessage{
|
||
Role: message.Role,
|
||
}
|
||
if message.IsStringContent() {
|
||
claudeMessage.Content = message.StringContent()
|
||
} else {
|
||
claudeMediaMessages := make([]claude.ClaudeMediaMessage, 0)
|
||
for _, mediaMessage := range message.ParseContent() {
|
||
claudeMediaMessage := claude.ClaudeMediaMessage{
|
||
Type: mediaMessage.Type,
|
||
}
|
||
if mediaMessage.Type == "text" {
|
||
claudeMediaMessage.Text = mediaMessage.Text
|
||
} else {
|
||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||
claudeMediaMessage.Type = "image"
|
||
claudeMediaMessage.Source = &claude.ClaudeMessageSource{
|
||
Type: "base64",
|
||
}
|
||
// 判断是否是url
|
||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||
// 是url,获取图片的类型和base64编码的数据
|
||
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
|
||
claudeMediaMessage.Source.MediaType = mimeType
|
||
claudeMediaMessage.Source.Data = data
|
||
} else {
|
||
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
claudeMediaMessage.Source.MediaType = "image/" + format
|
||
claudeMediaMessage.Source.Data = base64String
|
||
}
|
||
}
|
||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||
}
|
||
claudeMessage.Content = claudeMediaMessages
|
||
}
|
||
claudeMessages = append(claudeMessages, claudeMessage)
|
||
}
|
||
}
|
||
vertexClaudeRequest.Messages = claudeMessages
|
||
return &vertexClaudeRequest, nil
|
||
}
|
||
|
||
func vertexClaudeHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||
var claudeResponse claude.ClaudeResponse
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
err = resp.Body.Close()
|
||
if err != nil {
|
||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||
if err != nil {
|
||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
openaiResp := claude.ResponseClaude2OpenAI(claude.RequestModeMessage, &claudeResponse)
|
||
usage := relaymodel.Usage{
|
||
PromptTokens: claudeResponse.Usage.InputTokens,
|
||
CompletionTokens: claudeResponse.Usage.OutputTokens,
|
||
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
|
||
}
|
||
openaiResp.Usage = usage
|
||
c.JSON(http.StatusOK, openaiResp)
|
||
return nil, &usage
|
||
}
|
||
|
||
func vertexClaudeStreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||
if atEOF && len(data) == 0 {
|
||
return 0, nil, nil
|
||
}
|
||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||
return i + 1, data[0:i], nil
|
||
}
|
||
if atEOF {
|
||
return len(data), data, nil
|
||
}
|
||
return 0, nil, nil
|
||
})
|
||
dataChan := make(chan string)
|
||
stopChan := make(chan bool)
|
||
go func() {
|
||
for scanner.Scan() {
|
||
data := scanner.Text()
|
||
if len(data) < 5 { // ignore blank line or wrong format
|
||
continue
|
||
}
|
||
if data[:5] != "data:" {
|
||
continue
|
||
}
|
||
data = data[5:]
|
||
dataChan <- data
|
||
}
|
||
stopChan <- true
|
||
}()
|
||
var id string
|
||
var model string
|
||
createdTime := common.GetTimestamp()
|
||
var usage relaymodel.Usage
|
||
service.SetEventStreamHeaders(c)
|
||
c.Stream(func(w io.Writer) bool {
|
||
select {
|
||
case data := <-dataChan:
|
||
claudeResp := new(claude.ClaudeResponse)
|
||
err := json.Unmarshal([]byte(data), &claudeResp)
|
||
if err != nil {
|
||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||
return true
|
||
}
|
||
response, claudeUsage := claude.StreamResponseClaude2OpenAI(claude.RequestModeMessage, claudeResp)
|
||
|
||
if claudeUsage != nil {
|
||
usage.PromptTokens += claudeUsage.InputTokens
|
||
usage.CompletionTokens += claudeUsage.OutputTokens
|
||
}
|
||
|
||
if response == nil {
|
||
return true
|
||
}
|
||
|
||
if response.Id != "" {
|
||
id = response.Id
|
||
}
|
||
if response.Model != "" {
|
||
model = response.Model
|
||
}
|
||
response.Created = createdTime
|
||
response.Id = id
|
||
response.Model = model
|
||
|
||
jsonStr, err := json.Marshal(response)
|
||
if err != nil {
|
||
common.SysError("error marshalling stream response: " + err.Error())
|
||
return true
|
||
}
|
||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||
return true
|
||
case <-stopChan:
|
||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||
return false
|
||
}
|
||
})
|
||
err := resp.Body.Close()
|
||
if err != nil {
|
||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||
}
|
||
return nil, &usage
|
||
}
|