mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 01:26:37 +08:00
feat: Refactor message handling and update dependencies
- Added required packages to go.mod: `github.com/Laisky/errors/v2 v2.0.1`, `github.com/stretchr/testify v1.8.3`, `github.com/davecgh/go-spew v1.1.1`, `github.com/pmezard/go-difflib v1.0.0` - Increased the number of returned recordings to 100 in `relay-utils.go` - Refactored and simplified code in `relay-aiproxy.go` for improved message retrieval and error handling - Added new test file `relay_test.go` and various test cases for different message types - Modified functions in `group.go` and `relay.go` for improved functionality
This commit is contained in:
parent
8d270c8c9a
commit
8b477d896d
@ -23,17 +23,17 @@ func printHelp() {
|
||||
}
|
||||
|
||||
func init() {
|
||||
flag.Parse()
|
||||
// flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
fmt.Println(Version)
|
||||
os.Exit(0)
|
||||
}
|
||||
// if *PrintVersion {
|
||||
// fmt.Println(Version)
|
||||
// os.Exit(0)
|
||||
// }
|
||||
|
||||
if *PrintHelp {
|
||||
printHelp()
|
||||
os.Exit(0)
|
||||
}
|
||||
// if *PrintHelp {
|
||||
// printHelp()
|
||||
// os.Exit(0)
|
||||
// }
|
||||
|
||||
if os.Getenv("SESSION_SECRET") != "" {
|
||||
SessionSecret = os.Getenv("SESSION_SECRET")
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range common.GroupRatio {
|
||||
for groupName := range common.GroupRatio {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
@ -49,23 +49,13 @@ type AIProxyLibraryStreamResponse struct {
|
||||
|
||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||
query := ""
|
||||
|
||||
if request.MessagesLen() != 0 {
|
||||
switch msgs := request.Messages.(type) {
|
||||
case []Message:
|
||||
if msgs, err := request.TextMessages(); err == nil {
|
||||
query = msgs[len(msgs)-1].Content
|
||||
case []VisionMessage:
|
||||
} else if msgs, err := request.VisionMessages(); err == nil {
|
||||
query = msgs[len(msgs)-1].Content.Text
|
||||
case []any:
|
||||
msg := msgs[len(msgs)-1]
|
||||
switch msg := msg.(type) {
|
||||
case Message:
|
||||
query = msg.Content
|
||||
case VisionMessage:
|
||||
query = msg.Content.Text
|
||||
default:
|
||||
log.Panicf("unknown message type: %T", msg)
|
||||
}
|
||||
default:
|
||||
} else {
|
||||
log.Panicf("unknown message type: %T", msgs)
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func InitTokenEncoders() {
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
for model, _ := range common.ModelRatio {
|
||||
for model := range common.ModelRatio {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
|
@ -1,13 +1,14 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Laisky/errors/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@ -92,6 +93,8 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
|
||||
return len(msgs)
|
||||
case []VisionMessage:
|
||||
return len(msgs)
|
||||
case []map[string]any:
|
||||
return len(msgs)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@ -99,46 +102,24 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
|
||||
|
||||
// TextMessages returns messages as []Message
|
||||
func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
|
||||
switch msgs := r.Messages.(type) {
|
||||
case []any:
|
||||
messages = make([]Message, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if m, ok := msg.(Message); ok {
|
||||
messages = append(messages, m)
|
||||
} else {
|
||||
err = fmt.Errorf("invalid message type")
|
||||
return
|
||||
}
|
||||
}
|
||||
case []Message:
|
||||
messages = msgs
|
||||
default:
|
||||
return nil, errors.New("invalid message type")
|
||||
if blob, err := json.Marshal(r.Messages); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal messages failed")
|
||||
} else if err := json.Unmarshal(blob, &messages); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal messages failed")
|
||||
} else {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// VisionMessages returns messages as []VisionMessage
|
||||
func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) {
|
||||
switch msgs := r.Messages.(type) {
|
||||
case []any:
|
||||
messages = make([]VisionMessage, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if m, ok := msg.(VisionMessage); ok {
|
||||
messages = append(messages, m)
|
||||
} else {
|
||||
err = fmt.Errorf("invalid message type")
|
||||
return
|
||||
}
|
||||
}
|
||||
case []VisionMessage:
|
||||
messages = msgs
|
||||
default:
|
||||
return nil, errors.New("invalid message type")
|
||||
if blob, err := json.Marshal(r.Messages); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal vision messages failed")
|
||||
} else if err := json.Unmarshal(blob, &messages); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal vision messages failed")
|
||||
} else {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
|
61
controller/relay_test.go
Normal file
61
controller/relay_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeneralOpenAIRequest_TextMessages(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages interface{}
|
||||
want []Message
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "Test with []any messages",
|
||||
messages: []any{Message{}, Message{}},
|
||||
want: []Message{{}, {}},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Test with []Message messages",
|
||||
messages: []Message{{}, {}},
|
||||
want: []Message{{}, {}},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Test with invalid message type",
|
||||
messages: "invalid",
|
||||
want: nil,
|
||||
wantErr: fmt.Errorf("invalid message type string"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &GeneralOpenAIRequest{
|
||||
Messages: tt.messages,
|
||||
}
|
||||
got := new(GeneralOpenAIRequest)
|
||||
|
||||
blob, err := json.Marshal(r)
|
||||
require.NoError(t, err)
|
||||
err = json.Unmarshal(blob, got)
|
||||
require.NoError(t, err)
|
||||
|
||||
gotMessages, err := got.TextMessages()
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorContains(t, err, "cannot unmarshal string into Go value")
|
||||
return
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, tt.want, gotMessages)
|
||||
})
|
||||
}
|
||||
}
|
4
go.mod
4
go.mod
@ -3,6 +3,7 @@ module one-api
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/Laisky/errors/v2 v2.0.1
|
||||
github.com/gin-contrib/cors v1.4.0
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@ -12,6 +13,7 @@ require (
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/google/uuid v1.4.0
|
||||
github.com/pkoukk/tiktoken-go v0.1.6
|
||||
github.com/stretchr/testify v1.8.3
|
||||
golang.org/x/crypto v0.15.0
|
||||
gorm.io/driver/mysql v1.5.2
|
||||
gorm.io/driver/postgres v1.5.4
|
||||
@ -23,6 +25,7 @@ require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
@ -47,6 +50,7 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -1,3 +1,5 @@
|
||||
github.com/Laisky/errors/v2 v2.0.1 h1:yqCBrRzaP012AMB+7fVlXrP34OWRHrSO/hZ38CFdH84=
|
||||
github.com/Laisky/errors/v2 v2.0.1/go.mod h1:mTn1LHSmKm4CYug0rpYO7rz13dp/DKrtzlSELSrxvT0=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
|
Loading…
Reference in New Issue
Block a user