mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-12-26 17:55:58 +08:00
feat: Refactor message handling and update dependencies
- Added required packages to go.mod: `github.com/Laisky/errors/v2 v2.0.1`, `github.com/stretchr/testify v1.8.3`, `github.com/davecgh/go-spew v1.1.1`, `github.com/pmezard/go-difflib v1.0.0` - Increased the number of returned recordings to 100 in `relay-utils.go` - Refactored and simplified code in `relay-aiproxy.go` for improved message retrieval and error handling - Added new test file `relay_test.go` and various test cases for different message types - Modified functions in `group.go` and `relay.go` for improved functionality
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Laisky/errors/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -92,6 +93,8 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
|
||||
return len(msgs)
|
||||
case []VisionMessage:
|
||||
return len(msgs)
|
||||
case []map[string]any:
|
||||
return len(msgs)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -99,46 +102,24 @@ func (r *GeneralOpenAIRequest) MessagesLen() int {
|
||||
|
||||
// TextMessages returns messages as []Message
|
||||
func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
|
||||
switch msgs := r.Messages.(type) {
|
||||
case []any:
|
||||
messages = make([]Message, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if m, ok := msg.(Message); ok {
|
||||
messages = append(messages, m)
|
||||
} else {
|
||||
err = fmt.Errorf("invalid message type")
|
||||
return
|
||||
}
|
||||
}
|
||||
case []Message:
|
||||
messages = msgs
|
||||
default:
|
||||
return nil, errors.New("invalid message type")
|
||||
if blob, err := json.Marshal(r.Messages); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal messages failed")
|
||||
} else if err := json.Unmarshal(blob, &messages); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal messages failed")
|
||||
} else {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// VisionMessages returns messages as []VisionMessage
|
||||
func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) {
|
||||
switch msgs := r.Messages.(type) {
|
||||
case []any:
|
||||
messages = make([]VisionMessage, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if m, ok := msg.(VisionMessage); ok {
|
||||
messages = append(messages, m)
|
||||
} else {
|
||||
err = fmt.Errorf("invalid message type")
|
||||
return
|
||||
}
|
||||
}
|
||||
case []VisionMessage:
|
||||
messages = msgs
|
||||
default:
|
||||
return nil, errors.New("invalid message type")
|
||||
if blob, err := json.Marshal(r.Messages); err != nil {
|
||||
return nil, errors.Wrap(err, "marshal vision messages failed")
|
||||
} else if err := json.Unmarshal(blob, &messages); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal vision messages failed")
|
||||
} else {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
|
||||
Reference in New Issue
Block a user