one-api/relay/adaptor/openai/main.go
Laisky.Cai f6cfe7cd4f feat: enhance error handling and reasoning mechanisms across middleware
- Improve error handling across multiple middleware and adapter components, ensuring consistent error response formats in JSON.
- Enhance the functionality of request conversion functions by including context parameters and robust error wrapping.
- Introduce new features related to reasoning content in the messaging model, providing better customization and explanations in the documentation.
2025-02-26 05:38:21 +00:00

323 lines
9.9 KiB
Go

package openai
import (
"bufio"
"bytes"
"encoding/json"
"io"
"math"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
const (
dataPrefix = "data: "
done = "[DONE]"
dataPrefixLength = len(dataPrefix)
)
// StreamHandler processes streaming responses from OpenAI API
// It handles incremental content delivery and accumulates the final response text
// Returns error (if any), accumulated response text, and token usage information
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
// Initialize accumulators for the response
responseText := ""
reasoningText := ""
var usage *model.Usage
// Set up scanner for reading the stream line by line
scanner := bufio.NewScanner(resp.Body)
buffer := make([]byte, 256*1024) // 256KB buffer for large messages
scanner.Buffer(buffer, len(buffer))
scanner.Split(bufio.ScanLines)
// Set response headers for SSE
common.SetEventStreamHeaders(c)
doneRendered := false
// Process each line from the stream
for scanner.Scan() {
data := NormalizeDataLine(scanner.Text())
// Skip lines that don't match expected format
if len(data) < dataPrefixLength {
continue // Ignore blank line or wrong format
}
// Verify line starts with expected prefix
if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
continue
}
// Check for stream termination
if strings.HasPrefix(data[dataPrefixLength:], done) {
render.StringData(c, data)
doneRendered = true
continue
}
// Process based on relay mode
switch relayMode {
case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
// Parse the JSON response
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
render.StringData(c, data) // Pass raw data to client if parsing fails
continue
}
// Skip empty choices (Azure specific behavior)
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
continue
}
// Process each choice in the response
for _, choice := range streamResponse.Choices {
// Extract reasoning content from different possible fields
currentReasoningChunk := extractReasoningContent(&choice.Delta)
// Update accumulated reasoning text
if currentReasoningChunk != "" {
reasoningText += currentReasoningChunk
}
// Set the reasoning content in the format requested by client
choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
// Accumulate response content
responseText += conv.AsString(choice.Delta.Content)
}
// Send the processed data to the client
render.StringData(c, data)
// Update usage information if available
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
case relaymode.Completions:
// Send the data immediately for Completions mode
render.StringData(c, data)
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
continue
}
// Accumulate text from all choices
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
// Check for scanner errors
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
}
// Ensure stream termination is sent to client
if !doneRendered {
render.Done(c)
}
// Clean up resources
if err := resp.Body.Close(); err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
}
// Return the complete response text (reasoning + content) and usage
return nil, reasoningText + responseText, usage
}
// Helper function to extract reasoning content from message delta
func extractReasoningContent(delta *model.Message) string {
content := ""
// Extract reasoning from different possible fields
if delta.Reasoning != nil {
content += *delta.Reasoning
delta.Reasoning = nil
}
if delta.ReasoningContent != nil {
content += *delta.ReasoningContent
delta.ReasoningContent = nil
}
return content
}
// Handler processes non-streaming responses from OpenAI API
// Returns error (if any) and token usage information
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
// Read the entire response body
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
// Close the original response body
if err = resp.Body.Close(); err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// Parse the response JSON
var textResponse SlimTextResponse
if err = json.Unmarshal(responseBody, &textResponse); err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
// Check for API errors
if textResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{
Error: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Process reasoning content in each choice
for _, msg := range textResponse.Choices {
reasoningContent := processReasoningContent(&msg)
// Set reasoning in requested format if content exists
if reasoningContent != "" {
msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
}
}
// Reset response body for forwarding to client
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))
// Forward all response headers (not just first value of each)
for k, values := range resp.Header {
for _, v := range values {
c.Writer.Header().Add(k, v)
}
}
// Set response status and copy body to client
c.Writer.WriteHeader(resp.StatusCode)
if _, err = io.Copy(c.Writer, resp.Body); err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
// Close the reset body
if err = resp.Body.Close(); err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
// Calculate token usage if not provided by API
calculateTokenUsage(&textResponse, promptTokens, modelName)
return nil, &textResponse.Usage
}
// processReasoningContent is a helper function to extract and process reasoning content from the message
func processReasoningContent(msg *TextResponseChoice) string {
var reasoningContent string
// Check different locations for reasoning content
switch {
case msg.Reasoning != nil:
reasoningContent = *msg.Reasoning
msg.Reasoning = nil
case msg.ReasoningContent != nil:
reasoningContent = *msg.ReasoningContent
msg.ReasoningContent = nil
case msg.Message.Reasoning != nil:
reasoningContent = *msg.Message.Reasoning
msg.Message.Reasoning = nil
case msg.Message.ReasoningContent != nil:
reasoningContent = *msg.Message.ReasoningContent
msg.Message.ReasoningContent = nil
}
return reasoningContent
}
// Helper function to calculate token usage
func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
// Calculate tokens if not provided by the API
if response.Usage.TotalTokens == 0 ||
(response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range response.Choices {
// Count content tokens
completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
// Count reasoning tokens in all possible locations
if choice.Message.Reasoning != nil {
completionTokens += CountToken(*choice.Message.Reasoning)
}
if choice.Message.ReasoningContent != nil {
completionTokens += CountToken(*choice.Message.ReasoningContent)
}
if choice.Reasoning != nil {
completionTokens += CountToken(*choice.Reasoning)
}
if choice.ReasoningContent != nil {
completionTokens += CountToken(*choice.ReasoningContent)
}
}
// Set usage values
response.Usage = model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
} else if hasAudioTokens(response) {
// Handle audio tokens conversion
calculateAudioTokens(response, modelName)
}
}
// Helper function to check if response has audio tokens
func hasAudioTokens(response *SlimTextResponse) bool {
return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
(response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
}
// Helper function to calculate audio token usage
func calculateAudioTokens(response *SlimTextResponse, modelName string) {
// Convert audio tokens for prompt
if response.PromptTokensDetails != nil {
response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
int(math.Ceil(
float64(response.PromptTokensDetails.AudioTokens)*
ratio.GetAudioPromptRatio(modelName),
))
}
// Convert audio tokens for completion
if response.CompletionTokensDetails != nil {
response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
int(math.Ceil(
float64(response.CompletionTokensDetails.AudioTokens)*
ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
))
}
// Calculate total tokens
response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
}