feat: Refactor message handling and update dependencies

- Added required packages to go.mod: `github.com/Laisky/errors/v2 v2.0.1`, `github.com/stretchr/testify v1.8.3`, `github.com/davecgh/go-spew v1.1.1`, `github.com/pmezard/go-difflib v1.0.0`
- Increased the number of returned recordings to 100 in `relay-utils.go`
- Refactored and simplified code in `relay-aiproxy.go` for improved message retrieval and error handling
- Added new test file `relay_test.go` and various test cases for different message types
- Modified functions in `group.go` and `relay.go` for improved functionality
This commit is contained in:
Laisky.Cai 2023-11-17 03:24:55 +00:00
parent 8d270c8c9a
commit 8b477d896d
8 changed files with 98 additions and 60 deletions

View File

@ -23,17 +23,17 @@ func printHelp() {
} }
func init() { func init() {
flag.Parse() // flag.Parse()
if *PrintVersion { // if *PrintVersion {
fmt.Println(Version) // fmt.Println(Version)
os.Exit(0) // os.Exit(0)
} // }
if *PrintHelp { // if *PrintHelp {
printHelp() // printHelp()
os.Exit(0) // os.Exit(0)
} // }
if os.Getenv("SESSION_SECRET") != "" { if os.Getenv("SESSION_SECRET") != "" {
SessionSecret = os.Getenv("SESSION_SECRET") SessionSecret = os.Getenv("SESSION_SECRET")

View File

@ -8,7 +8,7 @@ import (
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio { for groupName := range common.GroupRatio {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -49,23 +49,13 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := "" query := ""
if request.MessagesLen() != 0 { if request.MessagesLen() != 0 {
switch msgs := request.Messages.(type) { if msgs, err := request.TextMessages(); err == nil {
case []Message:
query = msgs[len(msgs)-1].Content query = msgs[len(msgs)-1].Content
case []VisionMessage: } else if msgs, err := request.VisionMessages(); err == nil {
query = msgs[len(msgs)-1].Content.Text query = msgs[len(msgs)-1].Content.Text
case []any: } else {
msg := msgs[len(msgs)-1]
switch msg := msg.(type) {
case Message:
query = msg.Content
case VisionMessage:
query = msg.Content.Text
default:
log.Panicf("unknown message type: %T", msg)
}
default:
log.Panicf("unknown message type: %T", msgs) log.Panicf("unknown message type: %T", msgs)
} }
} }

View File

@ -29,7 +29,7 @@ func InitTokenEncoders() {
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
for model, _ := range common.ModelRatio { for model := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") { if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {

View File

@ -1,13 +1,14 @@
package controller package controller
import ( import (
"errors" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -92,6 +93,8 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
return len(msgs) return len(msgs)
case []VisionMessage: case []VisionMessage:
return len(msgs) return len(msgs)
case []map[string]any:
return len(msgs)
default: default:
return 0 return 0
} }
@ -99,46 +102,24 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
// TextMessages returns messages as []Message // TextMessages returns messages as []Message
func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) { func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
switch msgs := r.Messages.(type) { if blob, err := json.Marshal(r.Messages); err != nil {
case []any: return nil, errors.Wrap(err, "marshal messages failed")
messages = make([]Message, 0, len(msgs)) } else if err := json.Unmarshal(blob, &messages); err != nil {
for _, msg := range msgs { return nil, errors.Wrap(err, "unmarshal messages failed")
if m, ok := msg.(Message); ok {
messages = append(messages, m)
} else { } else {
err = fmt.Errorf("invalid message type") return messages, nil
return
} }
}
case []Message:
messages = msgs
default:
return nil, errors.New("invalid message type")
}
return
} }
// VisionMessages returns messages as []VisionMessage // VisionMessages returns messages as []VisionMessage
func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) { func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) {
switch msgs := r.Messages.(type) { if blob, err := json.Marshal(r.Messages); err != nil {
case []any: return nil, errors.Wrap(err, "marshal vision messages failed")
messages = make([]VisionMessage, 0, len(msgs)) } else if err := json.Unmarshal(blob, &messages); err != nil {
for _, msg := range msgs { return nil, errors.Wrap(err, "unmarshal vision messages failed")
if m, ok := msg.(VisionMessage); ok {
messages = append(messages, m)
} else { } else {
err = fmt.Errorf("invalid message type") return messages, nil
return
} }
}
case []VisionMessage:
messages = msgs
default:
return nil, errors.New("invalid message type")
}
return
} }
func (r GeneralOpenAIRequest) ParseInput() []string { func (r GeneralOpenAIRequest) ParseInput() []string {

61
controller/relay_test.go Normal file
View File

@ -0,0 +1,61 @@
package controller
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestGeneralOpenAIRequest_TextMessages(t *testing.T) {
tests := []struct {
name string
messages interface{}
want []Message
wantErr error
}{
{
name: "Test with []any messages",
messages: []any{Message{}, Message{}},
want: []Message{{}, {}},
wantErr: nil,
},
{
name: "Test with []Message messages",
messages: []Message{{}, {}},
want: []Message{{}, {}},
wantErr: nil,
},
{
name: "Test with invalid message type",
messages: "invalid",
want: nil,
wantErr: fmt.Errorf("invalid message type string"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &GeneralOpenAIRequest{
Messages: tt.messages,
}
got := new(GeneralOpenAIRequest)
blob, err := json.Marshal(r)
require.NoError(t, err)
err = json.Unmarshal(blob, got)
require.NoError(t, err)
gotMessages, err := got.TextMessages()
if tt.wantErr != nil {
require.ErrorContains(t, err, "cannot unmarshal string into Go value")
return
} else {
require.NoError(t, err)
}
require.Equal(t, tt.want, gotMessages)
})
}
}

4
go.mod
View File

@ -3,6 +3,7 @@ module one-api
go 1.21 go 1.21
require ( require (
github.com/Laisky/errors/v2 v2.0.1
github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
@ -12,6 +13,7 @@ require (
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.4.0 github.com/google/uuid v1.4.0
github.com/pkoukk/tiktoken-go v0.1.6 github.com/pkoukk/tiktoken-go v0.1.6
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.15.0 golang.org/x/crypto v0.15.0
gorm.io/driver/mysql v1.5.2 gorm.io/driver/mysql v1.5.2
gorm.io/driver/postgres v1.5.4 gorm.io/driver/postgres v1.5.4
@ -23,6 +25,7 @@ require (
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
@ -47,6 +50,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect

2
go.sum
View File

@ -1,3 +1,5 @@
github.com/Laisky/errors/v2 v2.0.1 h1:yqCBrRzaP012AMB+7fVlXrP34OWRHrSO/hZ38CFdH84=
github.com/Laisky/errors/v2 v2.0.1/go.mod h1:mTn1LHSmKm4CYug0rpYO7rz13dp/DKrtzlSELSrxvT0=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=