feat: Refactor message handling and update dependencies

- Added required packages to go.mod: `github.com/Laisky/errors/v2 v2.0.1`, `github.com/stretchr/testify v1.8.3`, `github.com/davecgh/go-spew v1.1.1`, `github.com/pmezard/go-difflib v1.0.0`
- Increased the number of returned recordings to 100 in `relay-utils.go`
- Refactored and simplified code in `relay-aiproxy.go` for improved message retrieval and error handling
- Added new test file `relay_test.go` and various test cases for different message types
- Modified functions in `group.go` and `relay.go` for improved functionality
This commit is contained in:
Laisky.Cai
2023-11-17 03:24:55 +00:00
parent 8d270c8c9a
commit 8b477d896d
8 changed files with 98 additions and 60 deletions

View File

@@ -1,13 +1,14 @@
package controller
import (
"errors"
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
)
@@ -92,6 +93,8 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
return len(msgs)
case []VisionMessage:
return len(msgs)
case []map[string]any:
return len(msgs)
default:
return 0
}
@@ -99,46 +102,24 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
// TextMessages returns messages as []Message
func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
switch msgs := r.Messages.(type) {
case []any:
messages = make([]Message, 0, len(msgs))
for _, msg := range msgs {
if m, ok := msg.(Message); ok {
messages = append(messages, m)
} else {
err = fmt.Errorf("invalid message type")
return
}
}
case []Message:
messages = msgs
default:
return nil, errors.New("invalid message type")
if blob, err := json.Marshal(r.Messages); err != nil {
return nil, errors.Wrap(err, "marshal messages failed")
} else if err := json.Unmarshal(blob, &messages); err != nil {
return nil, errors.Wrap(err, "unmarshal messages failed")
} else {
return messages, nil
}
return
}
// VisionMessages returns messages as []VisionMessage
func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) {
switch msgs := r.Messages.(type) {
case []any:
messages = make([]VisionMessage, 0, len(msgs))
for _, msg := range msgs {
if m, ok := msg.(VisionMessage); ok {
messages = append(messages, m)
} else {
err = fmt.Errorf("invalid message type")
return
}
}
case []VisionMessage:
messages = msgs
default:
return nil, errors.New("invalid message type")
if blob, err := json.Marshal(r.Messages); err != nil {
return nil, errors.Wrap(err, "marshal vision messages failed")
} else if err := json.Unmarshal(blob, &messages); err != nil {
return nil, errors.Wrap(err, "unmarshal vision messages failed")
} else {
return messages, nil
}
return
}
func (r GeneralOpenAIRequest) ParseInput() []string {