Merge pull request #103 from Calcium-Ion/dev

feat: support Claude 3
This commit is contained in:
Calcium-Ion 2024-03-08 19:47:24 +08:00 committed by GitHub
commit eca48268b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 260 additions and 76 deletions

View File

@ -12,7 +12,7 @@ import (
"strings" "strings"
) )
func DecodeBase64ImageData(base64String string) (image.Config, string, error) { func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
// 去除base64数据的URL前缀如果有 // 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 { if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:] base64String = base64String[idx+1:]
@ -22,13 +22,13 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
decodedData, err := base64.StdEncoding.DecodeString(base64String) decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil { if err != nil {
fmt.Println("Error: Failed to decode base64 string") fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, "", err return image.Config{}, "", "", err
} }
// 创建一个bytes.Buffer用于存储解码后的数据 // 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData) reader := bytes.NewReader(decodedData)
config, format, err := getImageConfig(reader) config, format, err := getImageConfig(reader)
return config, format, err return config, format, base64String, err
} }
func IsImageUrl(url string) (bool, error) { func IsImageUrl(url string) (bool, error) {
@ -42,6 +42,7 @@ func IsImageUrl(url string) (bool, error) {
return true, nil return true, nil
} }
// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) { func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url) isImage, err := IsImageUrl(url)
if !isImage { if !isImage {

View File

@ -82,6 +82,14 @@ func (m Message) StringContent() string {
return string(m.Content) return string(m.Content)
} }
func (m Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return true
}
return false
}
func (m Message) ParseContent() []MediaMessage { func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage var contentList []MediaMessage
var stringContent string var stringContent string
@ -130,9 +138,3 @@ func (m Message) ParseContent() []MediaMessage {
return nil return nil
} }
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@ -61,3 +61,9 @@ type CompletionsStreamResponse struct {
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} `json:"choices"` } `json:"choices"`
} }
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@ -6,10 +6,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"strings" "strings"
) )
@ -50,15 +50,15 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re
} }
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
common.SysLog(fmt.Sprintf("Request mode: %d", a.RequestMode))
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
//if a.RequestMode == RequestModeCompletion { if a.RequestMode == RequestModeCompletion {
// return requestOpenAI2ClaudeComplete(*request), nil return requestOpenAI2ClaudeComplete(*request), nil
//} else { } else {
// return requestOpenAI2ClaudeMessage(*request), nil return requestOpenAI2ClaudeMessage(*request)
//} }
return request, nil
} }
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
@ -67,11 +67,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
err, responseText = claudeStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@ -1,7 +1,7 @@
package claude package claude
var ModelList = []string{ var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229",
} }
var ChannelName = "claude" var ChannelName = "claude"

View File

@ -4,10 +4,32 @@ type ClaudeMetadata struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
} }
type ClaudeMediaMessage struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
}
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type ClaudeMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
type ClaudeRequest struct { type ClaudeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample"` System string `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
@ -22,8 +44,25 @@ type ClaudeError struct {
} }
type ClaudeResponse struct { type ClaudeResponse struct {
Id string `json:"id"`
Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"` Completion string `json:"completion"`
StopReason string `json:"stop_reason"` StopReason string `json:"stop_reason"`
Model string `json:"model"` Model string `json:"model"`
Error ClaudeError `json:"error"` Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
}
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} }

View File

@ -17,6 +17,8 @@ func stopReasonClaude2OpenAI(reason string) string {
switch reason { switch reason {
case "stop_sequence": case "stop_sequence":
return "stop" return "stop"
case "end_turn":
return "stop"
case "max_tokens": case "max_tokens":
return "length" return "length"
default: default:
@ -54,25 +56,108 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
return &claudeRequest return &claudeRequest
} }
//func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
// claudeRequest := ClaudeRequest{
//} Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
}
claudeMessages := make([]ClaudeMessage, 0)
for _, message := range textRequest.Messages {
if message.Role == "system" {
claudeRequest.System = message.StringContent()
} else {
claudeMessage := ClaudeMessage{
Role: message.Role,
}
if message.IsStringContent() {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := ClaudeMediaMessage{
Type: mediaMessage.Type,
}
if mediaMessage.Type == "text" {
claudeMediaMessage.Text = mediaMessage.Text
} else {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &ClaudeMessageSource{
Type: "base64",
}
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data
} else {
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
if err != nil {
return nil, err
}
claudeMediaMessage.Source.MediaType = "image/" + format
claudeMediaMessage.Source.Data = base64String
}
}
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
claudeMessage.Content = claudeMediaMessages
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { return &claudeRequest, nil
}
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.Content = claudeResponse.Completion choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" { if finishReason != "null" {
choice.FinishReason = &finishReason choice.FinishReason = &finishReason
} }
var response dto.ChatCompletionsStreamResponse } else {
response.Object = "chat.completion.chunk" if claudeResponse.Type == "message_start" {
response.Model = claudeResponse.Model response.Id = claudeResponse.Message.Id
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice} response.Model = claudeResponse.Message.Model
return &response claudeUsage = &claudeResponse.Message.Usage
} else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
}
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
} }
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
}
if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
Index: 0, Index: 0,
@ -83,26 +168,39 @@ func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.OpenAITextRespon
}, },
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
fullTextResponse := dto.OpenAITextResponse{ choices = append(choices, choice)
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), } else {
Object: "chat.completion", fullTextResponse.Id = claudeResponse.Id
Created: common.GetTimestamp(), for i, message := range claudeResponse.Content {
Choices: []dto.OpenAITextResponseChoice{choice}, content, _ := json.Marshal(message.Text)
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: "assistant",
Content: content,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
choices = append(choices, choice)
}
}
fullTextResponse.Choices = choices
return &fullTextResponse return &fullTextResponse
} }
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage dto.Usage
responseText := ""
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 4, data[0:i], nil return i + 1, data[0:i], nil
} }
if atEOF { if atEOF {
return len(data), data, nil return len(data), data, nil
@ -114,10 +212,10 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
go func() { go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") { if !strings.HasPrefix(data, "data: ") {
continue continue
} }
data = strings.TrimPrefix(data, "event: completion\r\ndata: ") data = strings.TrimPrefix(data, "data: ")
dataChan <- data dataChan <- data
} }
stopChan <- true stopChan <- true
@ -134,10 +232,31 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse) responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId response.Id = responseId
response.Created = createdTime response.Created = createdTime
response.Model = modelName
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())
@ -152,12 +271,15 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, responseText if requestMode == RequestModeCompletion {
usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
}
return nil, &usage
} }
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@ -182,12 +304,17 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil }, nil
} }
fullTextResponse := responseClaude2OpenAI(&claudeResponse) fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens := service.CountTokenText(claudeResponse.Completion, model) completionTokens := service.CountTokenText(claudeResponse.Completion, model)
usage := dto.Usage{ usage := dto.Usage{}
PromptTokens: promptTokens, if requestMode == RequestModeCompletion {
CompletionTokens: completionTokens, usage.PromptTokens = promptTokens
TotalTokens: promptTokens + completionTokens, usage.CompletionTokens = completionTokens
usage.TotalTokens = promptTokens + completionTokens
} else {
usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
} }
fullTextResponse.Usage = usage fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)

View File

@ -74,7 +74,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
config, format, err = common.DecodeUrlImageData(imageUrl.Url) config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else { } else {
common.SysLog(fmt.Sprintf("decoding image")) common.SysLog(fmt.Sprintf("decoding image"))
config, format, err = common.DecodeBase64ImageData(imageUrl.Url) config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
} }
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -63,7 +63,7 @@ const EditChannel = (props) => {
let localModels = []; let localModels = [];
switch (value) { switch (value) {
case 14: case 14:
localModels = ['claude-instant-1', 'claude-2']; localModels = ["claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229"];
break; break;
case 11: case 11:
localModels = ['PaLM-2']; localModels = ['PaLM-2'];

View File

@ -3,6 +3,8 @@ import {API, isMobile, showError, showInfo, showSuccess} from '../../helpers';
import {renderNumber, renderQuota} from '../../helpers/render'; import {renderNumber, renderQuota} from '../../helpers/render';
import {Col, Layout, Row, Typography, Card, Button, Form, Divider, Space, Modal} from "@douyinfe/semi-ui"; import {Col, Layout, Row, Typography, Card, Button, Form, Divider, Space, Modal} from "@douyinfe/semi-ui";
import Title from "@douyinfe/semi-ui/lib/es/typography/title"; import Title from "@douyinfe/semi-ui/lib/es/typography/title";
import Text from '@douyinfe/semi-ui/lib/es/typography/text';
import { Link } from 'react-router-dom';
const TopUp = () => { const TopUp = () => {
const [redemptionCode, setRedemptionCode] = useState(''); const [redemptionCode, setRedemptionCode] = useState('');
@ -290,6 +292,15 @@ const TopUp = () => {
</Space> </Space>
</Form> </Form>
</div> </div>
<div style={{ display: 'flex', justifyContent: 'right' }}>
<Text>
<Link onClick={
async () => {
window.location.href = '/topup/history'
}
}>充值记录</Link>
</Text>
</div>
</Card> </Card>
</div> </div>