new-api/relay/channel/vertex-claude/relay-vertex-claude.go
2024-06-23 20:21:42 +08:00

255 lines
7.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package vertex_claude
import (
"bufio"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaymodel "one-api/dto"
"one-api/relay/channel/claude"
"one-api/service"
"strings"
"sync"
"time"
)
// accessTokenMap caches OAuth2 tokens obtained by getAccessToken. Keys are the
// raw credentials JSON strings passed to that function; values are
// oauth2.Token structs stored by value.
var accessTokenMap sync.Map
// getAccessToken returns an OAuth2 access token for the Google credentials
// supplied as a JSON document in json, scoped to
// https://www.googleapis.com/auth/cloud-platform.
//
// Tokens are cached in accessTokenMap under the raw credentials string. A
// cached token is reused only while oauth2.Token.Valid reports it as usable;
// compared with a bare time.Now().Before(Expiry) check this stops handing out
// tokens that are about to expire and, per the oauth2 documentation, accepts
// tokens whose Expiry is zero instead of refetching them on every call.
// On a cache miss (or an unusable cached entry) a fresh token is fetched,
// stored, and returned.
//
// An error is returned when the credentials cannot be parsed or the token
// source fails to produce a token.
func getAccessToken(json string) (string, error) {
	if cached, ok := accessTokenMap.Load(json); ok {
		// Comma-ok assertion: an entry of an unexpected type is treated as a
		// cache miss rather than panicking.
		if token, ok := cached.(oauth2.Token); ok && token.Valid() {
			return token.AccessToken, nil
		}
	}

	creds, err := google.CredentialsFromJSON(context.Background(), []byte(json), "https://www.googleapis.com/auth/cloud-platform")
	if err != nil {
		return "", err
	}
	token, err := creds.TokenSource.Token()
	if err != nil {
		return "", err
	}
	accessTokenMap.Store(json, *token)
	return token.AccessToken, nil
}
// getRedirectModel looks requestModel up in modelIdMap and returns the mapped
// model identifier. When the map has no entry for requestModel, it returns an
// empty string and a "model ... not found" error.
func getRedirectModel(requestModel string) (string, error) {
	redirected, found := modelIdMap[requestModel]
	if !found {
		return "", errors.Errorf("model %s not found", requestModel)
	}
	return redirected, nil
}
// requestOpenAI2VertexClaude converts an OpenAI-style chat request into a
// VertexClaudeRequest.
//
// The conversion runs in two passes:
//
//  1. Normalisation: messages with an empty role default to "user", adjacent
//     messages that share a role and both carry plain string content are
//     merged into one (joined by a single space), and messages without any
//     content get the placeholder "...".
//  2. Translation: "system" messages populate the request's System field (the
//     last one wins); every other message becomes a claude.ClaudeMessage.
//     Multi-part content is translated part by part: "text" parts are copied,
//     any other part is treated as an image and inlined as base64, either
//     downloaded from an http(s) URL or decoded from a base64 data string.
//
// An error is returned when an image part has an unexpected payload type, when
// downloading an image fails, or when inline image data cannot be decoded.
func requestOpenAI2VertexClaude(request dto.GeneralOpenAIRequest) (*VertexClaudeRequest, error) {
	vertexClaudeRequest := VertexClaudeRequest{
		AnthropicVersion: "vertex-2023-10-16",
		Stream:           request.Stream,
	}
	// NOTE(review): MaxTokens is never copied from request in this function, so
	// this condition always holds and every request is sent with 4096.
	// Presumably the caller's limit should be honoured; confirm the field
	// types on both structs and wire it through.
	if vertexClaudeRequest.MaxTokens == 0 {
		vertexClaudeRequest.MaxTokens = 4096
	}

	formatMessages := make([]dto.Message, 0, len(request.Messages))
	for i, message := range request.Messages {
		if message.Role == "" {
			// Keep the historical write-back to the caller's slice, and also
			// fix up the loop copy so the default actually reaches the
			// outgoing message and the role comparison below.
			request.Messages[i].Role = "user"
			message.Role = "user"
		}
		fmtMessage := dto.Message{
			Role:    message.Role,
			Content: message.Content,
		}
		if n := len(formatMessages); n > 0 {
			// Merge against the last *formatted* message. Comparing against
			// the raw previous request message would drop already-merged text
			// when three or more same-role messages follow each other.
			previous := formatMessages[n-1]
			if previous.Role == message.Role && previous.IsStringContent() && message.IsStringContent() {
				merged := strings.Trim(fmt.Sprintf("%s %s", previous.StringContent(), message.StringContent()), "\"")
				// Marshalling a plain string cannot fail, so the error is ignored.
				content, _ := json.Marshal(merged)
				fmtMessage.Content = content
				// Drop the previous entry; the merged message replaces it.
				formatMessages = formatMessages[:n-1]
			}
		}
		if fmtMessage.Content == nil {
			// Marshalling a string literal cannot fail, so the error is ignored.
			content, _ := json.Marshal("...")
			fmtMessage.Content = content
		}
		formatMessages = append(formatMessages, fmtMessage)
	}

	claudeMessages := make([]claude.ClaudeMessage, 0, len(formatMessages))
	for _, message := range formatMessages {
		if message.Role == "system" {
			if message.IsStringContent() {
				vertexClaudeRequest.System = message.StringContent()
				continue
			}
			// Multi-part system prompt: concatenate the text parts only.
			var system strings.Builder
			for _, part := range message.ParseContent() {
				if part.Type == "text" {
					system.WriteString(part.Text)
				}
			}
			vertexClaudeRequest.System = system.String()
			continue
		}

		claudeMessage := claude.ClaudeMessage{
			Role: message.Role,
		}
		if message.IsStringContent() {
			claudeMessage.Content = message.StringContent()
			claudeMessages = append(claudeMessages, claudeMessage)
			continue
		}

		claudeMediaMessages := make([]claude.ClaudeMediaMessage, 0)
		for _, mediaMessage := range message.ParseContent() {
			claudeMediaMessage := claude.ClaudeMediaMessage{
				Type: mediaMessage.Type,
			}
			if mediaMessage.Type == "text" {
				claudeMediaMessage.Text = mediaMessage.Text
			} else {
				// Every non-text part is treated as an image.
				imageUrl, ok := mediaMessage.ImageUrl.(dto.MessageImageUrl)
				if !ok {
					return nil, errors.Errorf("unsupported image_url payload of type %T", mediaMessage.ImageUrl)
				}
				claudeMediaMessage.Type = "image"
				claudeMediaMessage.Source = &claude.ClaudeMessageSource{
					Type: "base64",
				}
				// Check whether the value is a URL.
				if strings.HasPrefix(imageUrl.Url, "http") {
					// It is a URL: fetch the image's MIME type and its
					// base64-encoded data.
					mimeType, data, err := common.GetImageFromUrl(imageUrl.Url)
					if err != nil {
						return nil, err
					}
					claudeMediaMessage.Source.MediaType = mimeType
					claudeMediaMessage.Source.Data = data
				} else {
					_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
					if err != nil {
						return nil, err
					}
					claudeMediaMessage.Source.MediaType = "image/" + format
					claudeMediaMessage.Source.Data = base64String
				}
			}
			claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
		}
		claudeMessage.Content = claudeMediaMessages
		claudeMessages = append(claudeMessages, claudeMessage)
	}
	vertexClaudeRequest.Messages = claudeMessages
	return &vertexClaudeRequest, nil
}
// vertexClaudeHandler handles a non-streaming Claude response: it reads the
// whole upstream body, decodes it as a claude.ClaudeResponse, converts it to
// the OpenAI response shape, and writes that to the client as JSON with
// status 200.
//
// It returns the token usage reported by the upstream response
// (input + output = total). If reading, closing, or decoding the body fails,
// it returns a wrapped OpenAI-style error with HTTP 500 and a nil usage, and
// nothing is written to the client.
func vertexClaudeHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	// Close unconditionally so that a failed read does not leak the body.
	closeErr := resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	if closeErr != nil {
		return service.OpenAIErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	var claudeResponse claude.ClaudeResponse
	if err := json.Unmarshal(responseBody, &claudeResponse); err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	openaiResp := claude.ResponseClaude2OpenAI(claude.RequestModeMessage, &claudeResponse)
	usage := relaymodel.Usage{
		PromptTokens:     claudeResponse.Usage.InputTokens,
		CompletionTokens: claudeResponse.Usage.OutputTokens,
		TotalTokens:      claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
	}
	openaiResp.Usage = usage
	c.JSON(http.StatusOK, openaiResp)
	return nil, &usage
}
// vertexClaudeStreamHandler relays a streaming (SSE) Claude response to the
// client as OpenAI-style stream chunks.
//
// A background goroutine reads the upstream body line by line and forwards
// the payload of every "data:" line; all other lines are ignored. Each
// payload is decoded as a claude.ClaudeResponse, converted with
// claude.StreamResponseClaude2OpenAI, stamped with a shared creation time and
// the most recently seen id/model, and written to the client. Payloads that
// fail to decode or encode are logged and skipped. When the upstream stream
// ends, "data: [DONE]" is sent.
//
// The returned usage accumulates the input/output token counts reported by
// the stream events and carries their sum as TotalTokens. A failure to close
// the upstream body is returned as a wrapped HTTP 500 error with a nil usage.
func vertexClaudeStreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
	// The scanner's default split function (bufio.ScanLines) yields one line
	// per token, so no custom splitter is needed.
	scanner := bufio.NewScanner(resp.Body)

	dataChan := make(chan string)
	stopChan := make(chan struct{}) // closed by the reader when it is finished
	done := make(chan struct{})     // closed by this function once streaming has ended

	go func() {
		defer close(stopChan)
		for scanner.Scan() {
			line := scanner.Text()
			if !strings.HasPrefix(line, "data:") { // ignore blank line or wrong format
				continue
			}
			// Stop instead of blocking forever if the consumer has gone away
			// (c.Stream returns early when the client disconnects).
			select {
			case dataChan <- line[5:]:
			case <-done:
				return
			}
		}
		if err := scanner.Err(); err != nil {
			select {
			case <-done:
				// The handler already finished and closed the body; a read
				// error here is expected and not worth logging.
			default:
				common.SysError("error reading stream response: " + err.Error())
			}
		}
	}()

	var id string
	var model string
	createdTime := common.GetTimestamp()
	var usage relaymodel.Usage
	service.SetEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			claudeResp := new(claude.ClaudeResponse)
			if err := json.Unmarshal([]byte(data), claudeResp); err != nil {
				common.SysError("error unmarshalling stream response: " + err.Error())
				return true
			}
			response, claudeUsage := claude.StreamResponseClaude2OpenAI(claude.RequestModeMessage, claudeResp)
			if claudeUsage != nil {
				usage.PromptTokens += claudeUsage.InputTokens
				usage.CompletionTokens += claudeUsage.OutputTokens
			}
			if response == nil {
				return true
			}
			// Remember the latest id/model and apply them to every chunk, so
			// chunks that arrive without them still carry consistent values.
			if response.Id != "" {
				id = response.Id
			}
			if response.Model != "" {
				model = response.Model
			}
			response.Created = createdTime
			response.Id = id
			response.Model = model
			jsonStr, err := json.Marshal(response)
			if err != nil {
				common.SysError("error marshalling stream response: " + err.Error())
				return true
			}
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})

	// Release the reader goroutine before closing the body: a pending send
	// unblocks via done, and a pending read unblocks when the body is closed.
	close(done)
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens

	err := resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, &usage
}