mirror of https://github.com/songquanpeng/one-api.git
✨ feat: Support stream_options
@@ -1,7 +1,9 @@
 package relay
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"math"
 	"net/http"
 	"one-api/common"
@@ -36,6 +38,10 @@ func (r *relayChat) setRequest() error {
 		r.c.Set("skip_only_chat", true)
 	}
 
+	if !r.chatRequest.Stream && r.chatRequest.StreamOptions != nil {
+		return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+	}
+
 	r.originalModel = r.chatRequest.Model
 
 	return nil
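
This guard mirrors OpenAI's own validation: stream_options is only accepted alongside "stream": true. A standalone sketch of the rule, using hypothetical stand-ins for one-api's request types reduced to the fields the guard inspects:

    package main

    import (
    	"errors"
    	"fmt"
    )

    // Hypothetical mirrors of one-api's request types, just enough
    // to make the new validation rule runnable in isolation.
    type StreamOptions struct {
    	IncludeUsage bool `json:"include_usage"`
    }

    type ChatRequest struct {
    	Stream        bool           `json:"stream"`
    	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
    }

    // validateStreamOptions reproduces the check added to setRequest.
    func validateStreamOptions(r *ChatRequest) error {
    	if !r.Stream && r.StreamOptions != nil {
    		return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
    	}
    	return nil
    }

    func main() {
    	opts := &StreamOptions{IncludeUsage: true}
    	fmt.Println(validateStreamOptions(&ChatRequest{Stream: false, StreamOptions: opts})) // rejected
    	fmt.Println(validateStreamOptions(&ChatRequest{Stream: true, StreamOptions: opts}))  // <nil>
    }
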
@@ -66,7 +72,11 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 			return
 		}
 
-		err = responseStreamClient(r.c, response, r.cache)
+		doneStr := func() string {
+			return r.getUsageResponse()
+		}
+
+		err = responseStreamClient(r.c, response, r.cache, doneStr)
 	} else {
 		var response *types.ChatCompletionResponse
 		response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
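
doneStr wraps getUsageResponse in a closure rather than passing a precomputed string: the provider's usage counters are only final once the stream has drained, and the closure defers the read until responseStreamClient invokes it at end of stream. A minimal illustration of the difference (hypothetical names, not the commit's code):

    package main

    import "fmt"

    func main() {
    	total := 0

    	// Eager: evaluated immediately, before any tokens have streamed.
    	eager := fmt.Sprintf(`{"usage":{"total_tokens":%d}}`, total)

    	// Deferred: a closure, evaluated only when invoked at end of stream.
    	deferred := func() string {
    		return fmt.Sprintf(`{"usage":{"total_tokens":%d}}`, total)
    	}

    	total = 21 // accounting that happens while the stream runs

    	fmt.Println(eager)      // {"usage":{"total_tokens":0}}  -- stale
    	fmt.Println(deferred()) // {"usage":{"total_tokens":21}} -- final usage
    }
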
@@ -86,3 +96,25 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 
 	return
 }
+
+func (r *relayChat) getUsageResponse() string {
+	if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
+		usageResponse := types.ChatCompletionStreamResponse{
+			ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   r.chatRequest.Model,
+			Choices: []types.ChatCompletionStreamChoice{},
+			Usage:   r.provider.GetUsage(),
+		}
+
+		responseBody, err := json.Marshal(usageResponse)
+		if err != nil {
+			return ""
+		}
+
+		return string(responseBody)
+	}
+
+	return ""
+}
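
The string produced here becomes one extra SSE frame with an empty choices array and a populated usage object. A self-contained sketch of what that frame looks like on the wire (the mirror types and values are illustrative, not one-api's actual definitions):

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    // Illustrative mirrors of the fields getUsageResponse serializes.
    type Usage struct {
    	PromptTokens     int `json:"prompt_tokens"`
    	CompletionTokens int `json:"completion_tokens"`
    	TotalTokens      int `json:"total_tokens"`
    }

    type StreamChunk struct {
    	ID      string     `json:"id"`
    	Object  string     `json:"object"`
    	Created int64      `json:"created"`
    	Model   string     `json:"model"`
    	Choices []struct{} `json:"choices"`
    	Usage   *Usage     `json:"usage"`
    }

    func main() {
    	chunk := StreamChunk{
    		ID:      "chatcmpl-example",
    		Object:  "chat.completion.chunk",
    		Created: 1700000000,
    		Model:   "gpt-3.5-turbo",
    		Choices: []struct{}{},
    		Usage:   &Usage{PromptTokens: 9, CompletionTokens: 12, TotalTokens: 21},
    	}

    	body, _ := json.Marshal(chunk)

    	// The frame a client sees just before "data: [DONE]":
    	fmt.Printf("data: %s\n\n", body)
    }
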
@@ -140,7 +140,9 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
 	return nil
 }
 
-func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
+type StreamEndHandler func() string
+
+func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
 	requester.SetEventStreamHeaders(c)
 	dataChan, errChan := stream.Recv()
 
@@ -160,6 +162,14 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
 				cache.NoCache()
 			}
 
+			if errWithOP == nil && endHandler != nil {
+				streamData := endHandler()
+				if streamData != "" {
+					fmt.Fprint(w, "data: "+streamData+"\n\n")
+					cache.SetResponse(streamData)
+				}
+			}
+
 			streamData := "data: [DONE]\n"
 			fmt.Fprint(w, streamData)
 			cache.SetResponse(streamData)
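
Note the gating: the usage frame is only emitted when the stream ended cleanly and a handler was supplied, while the [DONE] sentinel is written unconditionally, so old call sites passing nil keep their previous behavior. A compact sketch of that end-of-stream logic in isolation (the function and values are hypothetical):

    package main

    import (
    	"errors"
    	"fmt"
    )

    type StreamEndHandler func() string

    // finish mirrors the tail of responseStreamClient: a clean stream with
    // a handler gets one extra usage frame; [DONE] is written either way.
    func finish(streamErr error, endHandler StreamEndHandler) {
    	if streamErr == nil && endHandler != nil {
    		if tail := endHandler(); tail != "" {
    			fmt.Printf("data: %s\n\n", tail)
    		}
    	}
    	fmt.Print("data: [DONE]\n")
    }

    func main() {
    	usage := func() string { return `{"usage":{"total_tokens":21}}` }

    	finish(nil, usage)                    // usage frame, then [DONE]
    	finish(errors.New("upstream"), usage) // [DONE] only
    	finish(nil, nil)                      // nil handler: [DONE] only
    }
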
@@ -167,7 +177,7 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
 		}
 	})
 
-	return errWithOP
+	return nil
 }
 
 func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
@@ -1,7 +1,9 @@
 package relay
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"math"
 	"net/http"
 	"one-api/common"
@@ -32,6 +34,10 @@ func (r *relayCompletions) setRequest() error {
 		return errors.New("max_tokens is invalid")
 	}
 
+	if !r.request.Stream && r.request.StreamOptions != nil {
+		return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+	}
+
 	r.originalModel = r.request.Model
 
 	return nil
@@ -62,7 +68,11 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 			return
 		}
 
-		err = responseStreamClient(r.c, response, r.cache)
+		doneStr := func() string {
+			return r.getUsageResponse()
+		}
+
+		err = responseStreamClient(r.c, response, r.cache, doneStr)
 	} else {
 		var response *types.CompletionResponse
 		response, err = provider.CreateCompletion(&r.request)
@@ -79,3 +89,25 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 
 	return
 }
+
+func (r *relayCompletions) getUsageResponse() string {
+	if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage {
+		usageResponse := types.CompletionResponse{
+			ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   r.request.Model,
+			Choices: []types.CompletionChoice{},
+			Usage:   r.provider.GetUsage(),
+		}
+
+		responseBody, err := json.Marshal(usageResponse)
+		if err != nil {
+			return ""
+		}
+
+		return string(responseBody)
+	}
+
+	return ""
+}
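
From a client's perspective, enabling include_usage just means watching for the one chunk whose usage field is non-null before [DONE]. A hedged end-to-end sketch of consuming such a stream (the endpoint, port, and API key are placeholders; the JSON field names follow the OpenAI schema):

    package main

    import (
    	"bufio"
    	"bytes"
    	"encoding/json"
    	"fmt"
    	"net/http"
    	"strings"
    )

    // Usage mirrors only the fields this example reads.
    type Usage struct {
    	PromptTokens     int `json:"prompt_tokens"`
    	CompletionTokens int `json:"completion_tokens"`
    	TotalTokens      int `json:"total_tokens"`
    }

    type chunk struct {
    	Usage *Usage `json:"usage"`
    }

    func main() {
    	body := strings.NewReader(`{
    		"model": "gpt-3.5-turbo",
    		"messages": [{"role": "user", "content": "Hi"}],
    		"stream": true,
    		"stream_options": {"include_usage": true}
    	}`)

    	// Placeholder endpoint and key; adjust for your deployment.
    	req, _ := http.NewRequest("POST", "http://localhost:3000/v1/chat/completions", body)
    	req.Header.Set("Content-Type", "application/json")
    	req.Header.Set("Authorization", "Bearer sk-...")

    	resp, err := http.DefaultClient.Do(req)
    	if err != nil {
    		panic(err)
    	}
    	defer resp.Body.Close()

    	scanner := bufio.NewScanner(resp.Body)
    	for scanner.Scan() {
    		line := scanner.Bytes()
    		if !bytes.HasPrefix(line, []byte("data: ")) {
    			continue // skip blank separators and non-data lines
    		}
    		payload := bytes.TrimPrefix(line, []byte("data: "))
    		if string(payload) == "[DONE]" {
    			break
    		}
    		var c chunk
    		if json.Unmarshal(payload, &c) == nil && c.Usage != nil {
    			// The frame added by this commit: empty choices, usage set.
    			fmt.Printf("tokens: prompt=%d completion=%d total=%d\n",
    				c.Usage.PromptTokens, c.Usage.CompletionTokens, c.Usage.TotalTokens)
    		}
    	}
    }
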