feat: Support stream_options

This commit is contained in:
MartialBE
2024-05-26 19:58:15 +08:00
parent fa54ca7b50
commit eb260652b2
11 changed files with 188 additions and 31 deletions

View File

@@ -1,7 +1,9 @@
package relay
import (
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"one-api/common"
@@ -36,6 +38,10 @@ func (r *relayChat) setRequest() error {
r.c.Set("skip_only_chat", true)
}
if !r.chatRequest.Stream && r.chatRequest.StreamOptions != nil {
return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
}
r.originalModel = r.chatRequest.Model
return nil
@@ -66,7 +72,11 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return
}
err = responseStreamClient(r.c, response, r.cache)
doneStr := func() string {
return r.getUsageResponse()
}
err = responseStreamClient(r.c, response, r.cache, doneStr)
} else {
var response *types.ChatCompletionResponse
response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@@ -86,3 +96,25 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
return
}
// getUsageResponse builds the final SSE payload carrying token usage for a
// streamed chat completion. It returns the JSON-encoded chunk only when the
// client opted in via stream_options.include_usage; in every other case
// (option absent, opted out, or marshal failure) it returns "" so the caller
// emits nothing extra before the [DONE] sentinel.
func (r *relayChat) getUsageResponse() string {
	opts := r.chatRequest.StreamOptions
	if opts == nil || !opts.IncludeUsage {
		return ""
	}

	// Per the OpenAI spec, the usage chunk carries an empty choices list
	// and the aggregated usage gathered by the provider during the stream.
	chunk := types.ChatCompletionStreamResponse{
		ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
		Object:  "chat.completion.chunk",
		Created: common.GetTimestamp(),
		Model:   r.chatRequest.Model,
		Choices: []types.ChatCompletionStreamChoice{},
		Usage:   r.provider.GetUsage(),
	}

	encoded, marshalErr := json.Marshal(chunk)
	if marshalErr != nil {
		return ""
	}
	return string(encoded)
}

View File

@@ -140,7 +140,9 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
return nil
}
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
type StreamEndHandler func() string
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
requester.SetEventStreamHeaders(c)
dataChan, errChan := stream.Recv()
@@ -160,6 +162,14 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
cache.NoCache()
}
if errWithOP == nil && endHandler != nil {
streamData := endHandler()
if streamData != "" {
fmt.Fprint(w, "data: "+streamData+"\n\n")
cache.SetResponse(streamData)
}
}
streamData := "data: [DONE]\n"
fmt.Fprint(w, streamData)
cache.SetResponse(streamData)
@@ -167,7 +177,7 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
}
})
return errWithOP
return nil
}
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {

View File

@@ -1,7 +1,9 @@
package relay
import (
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"one-api/common"
@@ -32,6 +34,10 @@ func (r *relayCompletions) setRequest() error {
return errors.New("max_tokens is invalid")
}
if !r.request.Stream && r.request.StreamOptions != nil {
return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
}
r.originalModel = r.request.Model
return nil
@@ -62,7 +68,11 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return
}
err = responseStreamClient(r.c, response, r.cache)
doneStr := func() string {
return r.getUsageResponse()
}
err = responseStreamClient(r.c, response, r.cache, doneStr)
} else {
var response *types.CompletionResponse
response, err = provider.CreateCompletion(&r.request)
@@ -79,3 +89,25 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
return
}
// getUsageResponse builds the final SSE payload carrying token usage for a
// streamed text completion. It returns the JSON-encoded chunk only when the
// client opted in via stream_options.include_usage; otherwise (option absent,
// opted out, or marshal failure) it returns "" so no extra event is emitted.
func (r *relayCompletions) getUsageResponse() string {
	if r.request.StreamOptions == nil || !r.request.StreamOptions.IncludeUsage {
		return ""
	}

	// NOTE: the original used the chat values ("chatcmpl-" / "chat.completion.chunk"),
	// copied from the chat relay. The completions endpoint streams objects with
	// the "cmpl-" ID prefix and object type "text_completion" per the OpenAI API.
	usageResponse := types.CompletionResponse{
		ID:      fmt.Sprintf("cmpl-%s", common.GetUUID()),
		Object:  "text_completion",
		Created: common.GetTimestamp(),
		Model:   r.request.Model,
		Choices: []types.CompletionChoice{},
		Usage:   r.provider.GetUsage(),
	}

	responseBody, err := json.Marshal(usageResponse)
	if err != nil {
		return ""
	}
	return string(responseBody)
}