feat: support Claude 3

This commit is contained in:
CaIon 2024-03-08 19:43:33 +08:00
parent c2965eb835
commit 4a0af1ea3c
5 changed files with 81 additions and 33 deletions

View File

@ -10,7 +10,6 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
@ -68,9 +67,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
var responseText string
err, responseText = claudeStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else {
err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
}

View File

@ -1,7 +1,7 @@
package claude
var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229",
}
var ChannelName = "claude"

View File

@ -8,6 +8,8 @@ type ClaudeMediaMessage struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
}
type ClaudeMessageSource struct {
@ -50,8 +52,16 @@ type ClaudeResponse struct {
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
}
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`

View File

@ -17,6 +17,8 @@ func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "end_turn":
return "stop"
case "max_tokens":
return "length"
default:
@ -111,24 +113,41 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages
reqJson, _ := json.Marshal(claudeRequest)
common.SysLog(fmt.Sprintf("claude request: %s", reqJson))
return &claudeRequest, nil
}
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
} else {
if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
} else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
}
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
}
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
@ -170,17 +189,18 @@ func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse
}
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
responseText := ""
func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage dto.Usage
responseText := ""
createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
return i + 4, data[0:i], nil
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
@ -192,10 +212,10 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
go func() {
for scanner.Scan() {
data := scanner.Text()
if !strings.HasPrefix(data, "event: completion") {
if !strings.HasPrefix(data, "data: ") {
continue
}
data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
data = strings.TrimPrefix(data, "data: ")
dataChan <- data
}
stopChan <- true
@ -212,10 +232,31 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
response := streamResponseClaude2OpenAI(&claudeResponse)
responseId = response.Id
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
modelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else {
return true
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = modelName
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
@ -230,9 +271,12 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, responseText
if requestMode == RequestModeCompletion {
usage = *service.ResponseText2Usage(responseText, modelName, promptTokens)
}
return nil, &usage
}
func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@ -245,13 +289,10 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var claudeResponse ClaudeResponse
common.SysLog(fmt.Sprintf("claude response: %s", responseBody))
err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
respJson, _ := json.Marshal(claudeResponse)
common.SysLog(fmt.Sprintf("claude response json: %s", respJson))
if claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{

View File

@ -63,7 +63,7 @@ const EditChannel = (props) => {
let localModels = [];
switch (value) {
case 14:
localModels = ['claude-instant-1', 'claude-2'];
localModels = ["claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", "claude-3-sonnet-20240229", "claude-3-opus-20240229"];
break;
case 11:
localModels = ['PaLM-2'];