refactor: consolidate Anthropic model requests into GeneralOpenAIRequest

- Refactor Anthropic adapter to work with the new Anthropic API and model requests
- Remove the default value for `MaxTokensToSample`
- Default `MaxTokens` to 500 (instead of 1000000) when the request does not specify it, since the Anthropic messages API requires `max_tokens`
- Build the Anthropic `system` prompt from all `system`-role messages, instead of using only the first message
This commit is contained in:
Laisky.Cai 2024-03-05 03:38:26 +00:00
parent fb23ea0c9a
commit bb8755bc98
4 changed files with 26 additions and 40 deletions

View File

@ -3,13 +3,14 @@ package anthropic
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
) )
type Adaptor struct { type Adaptor struct {

View File

@ -28,37 +28,25 @@ func stopReasonClaude2OpenAI(reason string) string {
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
claudeRequest := Request{ claudeRequest := Request{
Model: textRequest.Model, GeneralOpenAIRequest: textRequest,
Prompt: "",
MaxTokensToSample: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
Stream: textRequest.Stream,
} }
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 1000000 if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 500 // max_tokens is required
} }
prompt := ""
// messages, err := textRequest.TextMessages() // anthropic's new messages API use system to represent the system prompt
// if err != nil { var filteredMessages []model.Message
// log.Panicf("invalid message type: %T", textRequest.Messages) for _, msg := range claudeRequest.Messages {
// } if msg.Role != "system" {
filteredMessages = append(filteredMessages, msg)
for _, message := range textRequest.Messages { continue
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
} }
claudeRequest.System += msg.Content.(string)
} }
prompt += "\n\nAssistant:" claudeRequest.Messages = filteredMessages
claudeRequest.Prompt = prompt
return &claudeRequest return &claudeRequest
} }

View File

@ -1,19 +1,17 @@
package anthropic package anthropic
import (
"github.com/songquanpeng/one-api/relay/model"
)
type Metadata struct { type Metadata struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
} }
type Request struct { type Request struct {
Model string `json:"model"` model.GeneralOpenAIRequest
Prompt string `json:"prompt"` // System anthropic messages API use system to represent the system prompt
MaxTokensToSample int `json:"max_tokens_to_sample"` System string `json:"system"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
} }
type Error struct { type Error struct {

View File

@ -3,7 +3,7 @@ package helper
import ( import (
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/aiproxy" "github.com/songquanpeng/one-api/relay/channel/aiproxy"
// "github.com/songquanpeng/one-api/relay/channel/anthropic" "github.com/songquanpeng/one-api/relay/channel/anthropic"
"github.com/songquanpeng/one-api/relay/channel/gemini" "github.com/songquanpeng/one-api/relay/channel/gemini"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel/palm" "github.com/songquanpeng/one-api/relay/channel/palm"
@ -17,8 +17,7 @@ func GetAdaptor(apiType int) channel.Adaptor {
// case constant.APITypeAli: // case constant.APITypeAli:
// return &ali.Adaptor{} // return &ali.Adaptor{}
case constant.APITypeAnthropic: case constant.APITypeAnthropic:
// return &anthropic.Adaptor{} return &anthropic.Adaptor{}
return &openai.Adaptor{}
// case constant.APITypeBaidu: // case constant.APITypeBaidu:
// return &baidu.Adaptor{} // return &baidu.Adaptor{}
case constant.APITypeGemini: case constant.APITypeGemini: