feat: support Claude 3

This commit is contained in:
CaIon 2024-03-08 19:43:33 +08:00
parent c2965eb835
commit 4a0af1ea3c
5 changed files with 81 additions and 33 deletions

View File

@ -10,7 +10,6 @@ import (
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"strings" "strings"
) )
@ -68,9 +67,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
err, responseText = claudeStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
} }

View File

@ -1,7 +1,7 @@
package claude package claude
var ModelList = []string{ var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229",
} }
var ChannelName = "claude" var ChannelName = "claude"

View File

@ -5,9 +5,11 @@ type ClaudeMetadata struct {
} }
type ClaudeMediaMessage struct { type ClaudeMediaMessage struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"` Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
} }
type ClaudeMessageSource struct { type ClaudeMessageSource struct {
@ -50,8 +52,16 @@ type ClaudeResponse struct {
Model string `json:"model"` Model string `json:"model"`
Error ClaudeError `json:"error"` Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"` Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
} }
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct { type ClaudeUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`

View File

@ -17,6 +17,8 @@ func stopReasonClaude2OpenAI(reason string) string {
switch reason { switch reason {
case "stop_sequence": case "stop_sequence":
return "stop" return "stop"
case "end_turn":
return "stop"
case "max_tokens": case "max_tokens":
return "length" return "length"
default: default:
@ -111,24 +113,41 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
} }
claudeRequest.Prompt = "" claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages claudeRequest.Messages = claudeMessages
reqJson, _ := json.Marshal(claudeRequest)
common.SysLog(fmt.Sprintf("claude request: %s", reqJson))
return &claudeRequest, nil return &claudeRequest, nil
} }
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse { func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice} response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
return &response var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
} else {
if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
} else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
}
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
} }
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
@ -170,17 +189,18 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse return &fullTextResponse
} }
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage dto.Usage
responseText := ""
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 4, data[0:i], nil return i + 1, data[0:i], nil
} }
if atEOF { if atEOF {
return len(data), data, nil return len(data), data, nil
@ -192,10 +212,10 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
go func() { go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") { if !strings.HasPrefix(data, "data: ") {
continue continue
} }
data = strings.TrimPrefix(data, "event: completion\r\ndata: ") data = strings.TrimPrefix(data, "data: ")
dataChan <- data dataChan <- data
} }
stopChan <- true stopChan <- true
@ -212,10 +232,31 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse) response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId response.Id = responseId
response.Created = createdTime response.Created = createdTime
response.Model = modelName
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())
@ -230,9 +271,12 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, responseText if requestMode == RequestModeCompletion {
usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
}
return nil, &usage
} }
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@ -245,13 +289,10 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
var claudeResponse ClaudeResponse var claudeResponse ClaudeResponse
common.SysLog(fmt.Sprintf("claude response: %s", responseBody))
err = json.Unmarshal(responseBody, &claudeResponse) err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
respJson, _ := json.Marshal(claudeResponse)
common.SysLog(fmt.Sprintf("claude response json: %s", respJson))
if claudeResponse.Error.Type != "" { if claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{ Error: dto.OpenAIError{

View File

@ -63,7 +63,7 @@ const EditChannel = (props) => {
let localModels = []; let localModels = [];
switch (value) { switch (value) {
case 14: case 14:
localModels = ['claude-instant-1', 'claude-2']; localModels = ["claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229"];
break; break;
case 11: case 11:
localModels = ['PaLM-2']; localModels = ['PaLM-2'];