Mirror of https://github.com/songquanpeng/one-api.git, synced 2025-09-18 01:26:37 +08:00
- Improve error handling across multiple middleware and adapter components, ensuring consistent JSON error response formats.
- Enhance request conversion functions by including context parameters and robust error wrapping.
- Introduce reasoning-content features in the messaging model, with customization options explained in the documentation.
323 lines
9.9 KiB
Go
package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"io"
	"math"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/conv"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/render"
	"github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)
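
// The upstream speaks the SSE convention used by the OpenAI API: each event
// arrives as a "data: <JSON>" line, and the stream is terminated by a
// "data: [DONE]" sentinel.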
const (
	dataPrefix       = "data: "
	done             = "[DONE]"
	dataPrefixLength = len(dataPrefix)
)

// StreamHandler processes streaming responses from the OpenAI API.
// It handles incremental content delivery and accumulates the final response text.
// It returns an error (if any), the accumulated response text, and token usage information.
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
	// Initialize accumulators for the response
	responseText := ""
	reasoningText := ""
	var usage *model.Usage

	// Set up a scanner that reads the stream line by line
	scanner := bufio.NewScanner(resp.Body)
	buffer := make([]byte, 256*1024) // 256KB buffer for large messages
	scanner.Buffer(buffer, len(buffer))
	scanner.Split(bufio.ScanLines)

	// Set response headers for SSE
	common.SetEventStreamHeaders(c)

	doneRendered := false

	// Process each line from the stream
	for scanner.Scan() {
		data := NormalizeDataLine(scanner.Text())

		// Skip lines that don't match the expected format
		if len(data) < dataPrefixLength {
			continue // ignore blank lines or wrong format
		}
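
		// NOTE: "data: " and "[DONE]" are both 6 bytes long, so the slice
		// comparison below also tolerates a bare "[DONE]" sentinel sent
		// without the data prefix.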
		// Verify the line starts with an expected prefix
		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
			continue
		}

		// Check for stream termination
		if strings.HasPrefix(data[dataPrefixLength:], done) {
			render.StringData(c, data)
			doneRendered = true
			continue
		}

		// Process based on relay mode
		switch relayMode {
		case relaymode.ChatCompletions:
			var streamResponse ChatCompletionsStreamResponse

			// Parse the JSON payload
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				render.StringData(c, data) // pass raw data to the client if parsing fails
				continue
			}

			// Skip chunks with no choices and no usage (Azure-specific behavior)
			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
				continue
			}
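
			// The raw upstream line is forwarded to the client verbatim below;
			// the parsed struct is used to accumulate text (for token counting)
			// and to extract reasoning content.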

			// Process each choice in the chunk
			for _, choice := range streamResponse.Choices {
				// Extract reasoning content from the possible delta fields
				currentReasoningChunk := extractReasoningContent(&choice.Delta)

				// Update the accumulated reasoning text
				if currentReasoningChunk != "" {
					reasoningText += currentReasoningChunk
				}

				// Set the reasoning content in the format requested by the client
				choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)

				// Accumulate response content
				responseText += conv.AsString(choice.Delta.Content)
			}

			// Send the processed data to the client
			render.StringData(c, data)

			// Update usage information if available
			if streamResponse.Usage != nil {
				usage = streamResponse.Usage
			}

		case relaymode.Completions:
			// Forward the data immediately in Completions mode
			render.StringData(c, data)

			var streamResponse CompletionsStreamResponse
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
			if err != nil {
				logger.SysError("error unmarshalling stream response: " + err.Error())
				continue
			}

			// Accumulate text from all choices
			for _, choice := range streamResponse.Choices {
				responseText += choice.Text
			}
		}
	}

	// Check for scanner errors
	if err := scanner.Err(); err != nil {
		logger.SysError("error reading stream: " + err.Error())
	}

	// Ensure stream termination is sent to the client
	if !doneRendered {
		render.Done(c)
	}

	// Clean up resources
	if err := resp.Body.Close(); err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
	}

	// Return the complete response text (reasoning + content) and usage
	return nil, reasoningText + responseText, usage
}
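
// Illustrative call site for StreamHandler (hypothetical; callers are expected
// in the relay layer), assuming resp holds an upstream streaming response:
//
//	relayErr, respText, usage := StreamHandler(c, resp, relaymode.ChatCompletions)
//	if relayErr != nil {
//		return relayErr // relay the wrapped error to the client
//	}
//	// respText can be used for token counting when usage is nil.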

// extractReasoningContent pulls reasoning text out of the possible fields of a
// message delta.
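// The source fields are cleared after reading so that the caller can re-set
// the reasoning in the client-requested format without duplicating it
// (assumed intent).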
func extractReasoningContent(delta *model.Message) string {
	content := ""

	// Extract reasoning from the possible fields
	if delta.Reasoning != nil {
		content += *delta.Reasoning
		delta.Reasoning = nil
	}

	if delta.ReasoningContent != nil {
		content += *delta.ReasoningContent
		delta.ReasoningContent = nil
	}

	return content
}

// Handler processes non-streaming responses from the OpenAI API.
// It returns an error (if any) and token usage information.
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
	// Read the entire response body
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	// Close the original response body
	if err = resp.Body.Close(); err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	// Parse the response JSON
	var textResponse SlimTextResponse
	if err = json.Unmarshal(responseBody, &textResponse); err != nil {
		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	// Check for API errors
	if textResponse.Error.Type != "" {
		return &model.ErrorWithStatusCode{
			Error:      textResponse.Error,
			StatusCode: resp.StatusCode,
		}, nil
	}

	// Process reasoning content in each choice
	for _, msg := range textResponse.Choices {
		reasoningContent := processReasoningContent(&msg)

		// Set reasoning in the requested format if content exists
		if reasoningContent != "" {
			msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
		}
	}

	// Reset the response body for forwarding to the client
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))

	// Forward all response headers (not just the first value of each)
	for k, values := range resp.Header {
		for _, v := range values {
			c.Writer.Header().Add(k, v)
		}
	}
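
	// Note: headers must be set before WriteHeader below; net/http ignores
	// changes to the header map once the status line has been written.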

	// Set the response status and copy the body to the client
	c.Writer.WriteHeader(resp.StatusCode)
	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}

	// Close the reset body
	if err = resp.Body.Close(); err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}

	// Calculate token usage if not provided by the API
	calculateTokenUsage(&textResponse, promptTokens, modelName)

	return nil, &textResponse.Usage
}
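
// Illustrative call site for Handler (hypothetical; callers are expected in
// the relay layer), assuming a non-streaming upstream response:
//
//	relayErr, usage := Handler(c, resp, promptTokens, modelName)
//	if relayErr != nil {
//		return relayErr
//	}
//	// usage now contains prompt/completion/total token counts for billing.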

// processReasoningContent is a helper that extracts and processes reasoning
// content from a response choice.
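// Choice-level fields take precedence over message-level ones; only the first
// populated source is used, and it is cleared after reading.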
func processReasoningContent(msg *TextResponseChoice) string {
	var reasoningContent string

	// Check the different locations for reasoning content
	switch {
	case msg.Reasoning != nil:
		reasoningContent = *msg.Reasoning
		msg.Reasoning = nil
	case msg.ReasoningContent != nil:
		reasoningContent = *msg.ReasoningContent
		msg.ReasoningContent = nil
	case msg.Message.Reasoning != nil:
		reasoningContent = *msg.Message.Reasoning
		msg.Message.Reasoning = nil
	case msg.Message.ReasoningContent != nil:
		reasoningContent = *msg.Message.ReasoningContent
		msg.Message.ReasoningContent = nil
	}

	return reasoningContent
}

// calculateTokenUsage fills in usage when the upstream API omits it and
// converts audio tokens into billable text-token equivalents.
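// Example with hypothetical numbers: given promptTokens = 50 and choices that
// decode to 120 content tokens plus 80 reasoning tokens, the fallback path
// yields Usage{PromptTokens: 50, CompletionTokens: 200, TotalTokens: 250}.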
func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
	// Calculate tokens if not provided by the API
	if response.Usage.TotalTokens == 0 ||
		(response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {

		completionTokens := 0
		for _, choice := range response.Choices {
			// Count content tokens
			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)

			// Count reasoning tokens in all possible locations
			if choice.Message.Reasoning != nil {
				completionTokens += CountToken(*choice.Message.Reasoning)
			}
			if choice.Message.ReasoningContent != nil {
				completionTokens += CountToken(*choice.Message.ReasoningContent)
			}
			if choice.Reasoning != nil {
				completionTokens += CountToken(*choice.Reasoning)
			}
			if choice.ReasoningContent != nil {
				completionTokens += CountToken(*choice.ReasoningContent)
			}
		}

		// Set usage values
		response.Usage = model.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		}
	} else if hasAudioTokens(response) {
		// Handle audio token conversion
		calculateAudioTokens(response, modelName)
	}
}

// hasAudioTokens reports whether the response contains audio tokens.
func hasAudioTokens(response *SlimTextResponse) bool {
	return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
		(response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
}

// calculateAudioTokens converts audio tokens into text-token equivalents using
// model-specific price ratios.
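// For example (hypothetical ratio): with GetAudioPromptRatio(modelName) == 8,
// 10 audio prompt tokens are billed as ceil(10*8) = 80 text-equivalent tokens
// on top of the prompt's text tokens.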
func calculateAudioTokens(response *SlimTextResponse, modelName string) {
	// Convert audio tokens for the prompt
	if response.PromptTokensDetails != nil {
		response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
			int(math.Ceil(
				float64(response.PromptTokensDetails.AudioTokens)*
					ratio.GetAudioPromptRatio(modelName),
			))
	}

	// Convert audio tokens for the completion
	if response.CompletionTokensDetails != nil {
		response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
			int(math.Ceil(
				float64(response.CompletionTokensDetails.AudioTokens)*
					ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
			))
	}

	// Calculate total tokens
	response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
}