diff --git a/common/init.go b/common/init.go index 1e9c85ce..e04612d5 100644 --- a/common/init.go +++ b/common/init.go @@ -23,17 +23,17 @@ func printHelp() { } func init() { - flag.Parse() + // flag.Parse() - if *PrintVersion { - fmt.Println(Version) - os.Exit(0) - } + // if *PrintVersion { + // fmt.Println(Version) + // os.Exit(0) + // } - if *PrintHelp { - printHelp() - os.Exit(0) - } + // if *PrintHelp { + // printHelp() + // os.Exit(0) + // } if os.Getenv("SESSION_SECRET") != "" { SessionSecret = os.Getenv("SESSION_SECRET") diff --git a/controller/group.go b/controller/group.go index 2b2f6006..d959bd37 100644 --- a/controller/group.go +++ b/controller/group.go @@ -8,7 +8,7 @@ import ( func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName, _ := range common.GroupRatio { + for groupName := range common.GroupRatio { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go index ca0911ba..8a664b0b 100644 --- a/controller/relay-aiproxy.go +++ b/controller/relay-aiproxy.go @@ -49,23 +49,13 @@ type AIProxyLibraryStreamResponse struct { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { query := "" + if request.MessagesLen() != 0 { - switch msgs := request.Messages.(type) { - case []Message: + if msgs, err := request.TextMessages(); err == nil { query = msgs[len(msgs)-1].Content - case []VisionMessage: + } else if msgs, err := request.VisionMessages(); err == nil { query = msgs[len(msgs)-1].Content.Text - case []any: - msg := msgs[len(msgs)-1] - switch msg := msg.(type) { - case Message: - query = msg.Content - case VisionMessage: - query = msg.Content.Text - default: - log.Panicf("unknown message type: %T", msg) - } - default: + } else { log.Panicf("unknown message type: %T", msgs) } } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index cf5d9b69..407d876b 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -29,7 +29,7 @@ func InitTokenEncoders() { if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } - for model, _ := range common.ModelRatio { + for model := range common.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = gpt35TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { diff --git a/controller/relay.go b/controller/relay.go index bed5a2e2..5d41d10b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,13 +1,14 @@ package controller import ( - "errors" + "encoding/json" "fmt" "net/http" "one-api/common" "strconv" "strings" + "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" ) @@ -92,6 +93,8 @@ func (r *GeneralOpenAIRequest) MessagesLen() int { return len(msgs) case []VisionMessage: return len(msgs) + case []map[string]any: + return len(msgs) default: return 0 } @@ -99,46 +102,24 @@ func (r *GeneralOpenAIRequest) MessagesLen() int { // TextMessages returns messages as []Message func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) { - switch msgs := r.Messages.(type) { - case []any: - messages = make([]Message, 0, len(msgs)) - for _, msg := range msgs { - if m, ok := msg.(Message); ok { - messages = append(messages, m) - } else { - err = fmt.Errorf("invalid message type") - return - } - } - case []Message: - messages = msgs - default: - return nil, errors.New("invalid message type") + if blob, err := json.Marshal(r.Messages); err != nil { + return nil, errors.Wrap(err, "marshal messages failed") + } else if err := json.Unmarshal(blob, &messages); err != nil { + return nil, errors.Wrap(err, "unmarshal messages failed") + } else { + return messages, nil } - - return } // VisionMessages returns messages as []VisionMessage func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) { - switch msgs := r.Messages.(type) { - case []any: - messages = make([]VisionMessage, 0, len(msgs)) - for _, msg := range msgs { - if m, ok := msg.(VisionMessage); ok { - messages = append(messages, m) - } else { - err = fmt.Errorf("invalid message type") - return - } - } - case []VisionMessage: - messages = msgs - default: - return nil, errors.New("invalid message type") + if blob, err := json.Marshal(r.Messages); err != nil { + return nil, errors.Wrap(err, "marshal vision messages failed") + } else if err := json.Unmarshal(blob, &messages); err != nil { + return nil, errors.Wrap(err, "unmarshal vision messages failed") + } else { + return messages, nil } - - return } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/controller/relay_test.go b/controller/relay_test.go new file mode 100644 index 00000000..9bb41c9a --- /dev/null +++ b/controller/relay_test.go @@ -0,0 +1,61 @@ +package controller + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGeneralOpenAIRequest_TextMessages(t *testing.T) { + tests := []struct { + name string + messages interface{} + want []Message + wantErr error + }{ + { + name: "Test with []any messages", + messages: []any{Message{}, Message{}}, + want: []Message{{}, {}}, + wantErr: nil, + }, + { + name: "Test with []Message messages", + messages: []Message{{}, {}}, + want: []Message{{}, {}}, + wantErr: nil, + }, + { + name: "Test with invalid message type", + messages: "invalid", + want: nil, + wantErr: fmt.Errorf("invalid message type string"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &GeneralOpenAIRequest{ + Messages: tt.messages, + } + got := new(GeneralOpenAIRequest) + + blob, err := json.Marshal(r) + require.NoError(t, err) + err = json.Unmarshal(blob, got) + require.NoError(t, err) + + gotMessages, err := got.TextMessages() + if tt.wantErr != nil { + require.ErrorContains(t, err, "cannot unmarshal string into Go value") + return + } else { + require.NoError(t, err) + } + + require.Equal(t, tt.want, gotMessages) + }) + } +} diff --git a/go.mod b/go.mod index 3922c5d4..c4dd5685 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module one-api go 1.21 require ( + github.com/Laisky/errors/v2 v2.0.1 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -12,6 +13,7 @@ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.4.0 github.com/pkoukk/tiktoken-go v0.1.6 + github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.15.0 gorm.io/driver/mysql v1.5.2 gorm.io/driver/postgres v1.5.4 @@ -23,6 +25,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect @@ -47,6 +50,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect diff --git a/go.sum b/go.sum index 5257f85b..b78729ef 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/Laisky/errors/v2 v2.0.1 h1:yqCBrRzaP012AMB+7fVlXrP34OWRHrSO/hZ38CFdH84= +github.com/Laisky/errors/v2 v2.0.1/go.mod h1:mTn1LHSmKm4CYug0rpYO7rz13dp/DKrtzlSELSrxvT0= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=