Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2023-12-12 06:00:29 +00:00
31 changed files with 1582 additions and 387 deletions

View File

@@ -1,20 +1,18 @@
package controller
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"strconv"
"strings"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
}
@@ -56,6 +54,45 @@ type OpenaiVisionMessageContentImageUrl struct {
Detail VisionImageResolution `json:"detail,omitempty"`
}
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == "text" {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
@@ -64,63 +101,75 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeAudio
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
// https://platform.openai.com/docs/api-reference/chat
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
// Messages maybe []Message or []VisionMessage
Messages any `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
func (r *GeneralOpenAIRequest) MessagesLen() int {
switch msgs := r.Messages.(type) {
case []any:
return len(msgs)
case []Message:
return len(msgs)
case []VisionMessage:
return len(msgs)
case []map[string]any:
return len(msgs)
default:
return 0
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
// func (r *GeneralOpenAIRequest) MessagesLen() int {
// switch msgs := r.Messages.(type) {
// case []any:
// return len(msgs)
// case []Message:
// return len(msgs)
// case []VisionMessage:
// return len(msgs)
// case []map[string]any:
// return len(msgs)
// default:
// return 0
// }
// }
// TextMessages returns messages as []Message
func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
if blob, err := json.Marshal(r.Messages); err != nil {
return nil, errors.Wrap(err, "marshal messages failed")
} else if err := json.Unmarshal(blob, &messages); err != nil {
return nil, errors.Wrapf(err, "unmarshal messages failed %q", string(blob))
} else {
return messages, nil
}
}
// func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) {
// if blob, err := json.Marshal(r.Messages); err != nil {
// return nil, errors.Wrap(err, "marshal messages failed")
// } else if err := json.Unmarshal(blob, &messages); err != nil {
// return nil, errors.Wrapf(err, "unmarshal messages failed %q", string(blob))
// } else {
// return messages, nil
// }
// }
// VisionMessages returns messages as []VisionMessage
func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) {
if blob, err := json.Marshal(r.Messages); err != nil {
return nil, errors.Wrap(err, "marshal vision messages failed")
} else if err := json.Unmarshal(blob, &messages); err != nil {
return nil, errors.Wrapf(err, "unmarshal vision messages failed %q", string(blob))
} else {
return messages, nil
}
}
// func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) {
// if blob, err := json.Marshal(r.Messages); err != nil {
// return nil, errors.Wrap(err, "marshal vision messages failed")
// } else if err := json.Unmarshal(blob, &messages); err != nil {
// return nil, errors.Wrapf(err, "unmarshal vision messages failed %q", string(blob))
// } else {
// return messages, nil
// }
// }
func (r GeneralOpenAIRequest) ParseInput() []string {
if r.Input == nil {
@@ -155,16 +204,51 @@ type TextRequest struct {
//Stream bool `json:"stream"`
}
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct {
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}
type AudioResponse struct {
type WhisperJSONResponse struct {
Text string `json:"text,omitempty"`
}
type WhisperVerboseJSONResponse struct {
Task string `json:"task,omitempty"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
Text string `json:"text,omitempty"`
Segments []Segment `json:"segments,omitempty"`
}
type Segment struct {
Id int `json:"id"`
Seek int `json:"seek"`
Start float64 `json:"start"`
End float64 `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
Temperature float64 `json:"temperature"`
AvgLogprob float64 `json:"avg_logprob"`
CompressionRatio float64 `json:"compression_ratio"`
NoSpeechProb float64 `json:"no_speech_prob"`
}
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
@@ -261,14 +345,22 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode = RelayModeAudio
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudio:
case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)