Add support for processing <think> tags in both streaming and non-streaming responses

This commit is contained in:
Ben Gao 2025-03-21 22:49:47 +08:00
parent 8df4a2670b
commit 768a1be11a
3 changed files with 360 additions and 7 deletions

View File

@ -32,6 +32,10 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
common.SetEventStreamHeaders(c)
// Variables to track <think> tag state across chunks
inThinkTag := false
var reasoningBuilder strings.Builder
doneRendered := false
for scanner.Scan() {
data := scanner.Text()
@ -52,14 +56,50 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
render.StringData(c, data)
continue
}
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
// but for empty choice and no usage, we should not pass it to client, this is for azure
continue // just ignore empty choice
continue
}
render.StringData(c, data)
// Process <think> tags before rendering
for i := range streamResponse.Choices {
if streamResponse.Choices[i].Delta.Content != nil {
content := conv.AsString(streamResponse.Choices[i].Delta.Content)
logger.Debugf(c.Request.Context(), "Original content: %s", content)
// Process the content for <think> tags
cleanContent, reasoningContent, newInThinkTag := processStreamThinkTag(content, inThinkTag, &reasoningBuilder)
inThinkTag = newInThinkTag
// Update content
streamResponse.Choices[i].Delta.Content = cleanContent
// If there's reasoning content, add it to reasoning_content
if reasoningContent != "" {
var reasoningContentAny any = reasoningContent
streamResponse.Choices[i].Delta.ReasoningContent = reasoningContentAny
logger.Debugf(c.Request.Context(), "Setting reasoning_content: %s", reasoningContent)
}
logger.Debugf(c.Request.Context(), "Processed content: clean=%s, reasoning=%s, inThinkTag=%v",
cleanContent, reasoningContent, inThinkTag)
}
}
// Re-marshal the modified response
modifiedData, err := json.Marshal(streamResponse)
if err != nil {
logger.SysError("error marshalling modified stream response: " + err.Error())
render.StringData(c, data) // if error happened, pass the original data to client
} else {
modifiedDataStr := dataPrefix + string(modifiedData)
logger.Debugf(c.Request.Context(), "Modified response: %s", modifiedDataStr)
render.StringData(c, modifiedDataStr)
}
// Update responseText with cleaned content
for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
@ -116,8 +156,46 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
// Process <think> tags in the response
modified := false
for i := range textResponse.Choices {
if textResponse.Choices[i].Message.Content != nil {
content := textResponse.Choices[i].Message.StringContent()
cleanContent, reasoningContent := extractThinkContent(content)
// If content was modified, update it
if content != cleanContent || reasoningContent != "" {
textResponse.Choices[i].Message.Content = cleanContent
// If there's reasoning content, add it to reasoning_content
if reasoningContent != "" {
// Make sure ReasoningContent is set as a string, not any other type
var reasoningContentAny any = reasoningContent
textResponse.Choices[i].Message.ReasoningContent = reasoningContentAny
}
modified = true
}
}
}
// If the response was modified, re-marshal it
if modified {
modifiedResponseBody, err := json.Marshal(textResponse)
if err != nil {
logger.SysError("error marshalling modified response: " + err.Error())
// If there's an error, use the original response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} else {
// Use the modified response body
responseBody = modifiedResponseBody
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
} else {
// Reset response body with original content
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.

View File

@ -3,6 +3,8 @@ package openai
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
@ -21,3 +23,129 @@ func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatus
StatusCode: statusCode,
}
}
// thinkBlockRe matches one complete <think>...</think> block. The lazy
// quantifier keeps neighbouring blocks separate and [\s\S] lets a block span
// several lines. It is compiled once at package scope instead of on every call.
var thinkBlockRe = regexp.MustCompile(`<think>([\s\S]*?)</think>`)

// extractThinkContent separates reasoning wrapped in <think>...</think> from
// the visible answer of a complete (non-streaming) message.
//
// Returns:
//   - cleanContent: content with every complete <think> block removed and
//     leading/trailing whitespace trimmed. Whitespace inside the visible text
//     (newlines, indentation, paragraph breaks) is preserved; only the seam left
//     by a removed block is tidied so that "a <think>x</think> b" becomes "a b"
//     rather than "a  b".
//   - reasoningContent: the bodies of all blocks, in order, each followed by a
//     newline.
//
// When content is empty or holds no complete block (for example an opening tag
// without its closing tag), content is returned untouched and reasoningContent
// is empty.
func extractThinkContent(content string) (cleanContent string, reasoningContent string) {
	// Whitespace bytes recognised when re-joining text around a removed block.
	const blank = " \t\n\r\f"

	if content == "" {
		return content, ""
	}
	blocks := thinkBlockRe.FindAllStringSubmatchIndex(content, -1)
	if len(blocks) == 0 {
		return content, ""
	}

	var clean, reasoning strings.Builder
	clean.Grow(len(content))

	// appendVisible adds a stretch of visible text. When the stretch follows a
	// removed block and the text kept so far is empty or already ends in
	// whitespace, the stretch's leading whitespace is dropped so the removal
	// does not leave a doubled gap.
	appendVisible := func(segment string, afterBlock bool) {
		if afterBlock {
			sofar := clean.String()
			if sofar == "" || strings.TrimRight(sofar, blank) != sofar {
				segment = strings.TrimLeft(segment, blank)
			}
		}
		clean.WriteString(segment)
	}

	cursor := 0
	for i, loc := range blocks {
		// loc[0]:loc[1] spans the whole block, loc[2]:loc[3] its body.
		appendVisible(content[cursor:loc[0]], i > 0)
		reasoning.WriteString(content[loc[2]:loc[3]])
		reasoning.WriteByte('\n')
		cursor = loc[1]
	}
	appendVisible(content[cursor:], true)

	return strings.TrimSpace(clean.String()), reasoning.String()
}
// processStreamThinkTag splits one streamed delta into visible text and
// reasoning text according to the <think>/</think> tags it contains.
//
// The chunk is scanned left to right, so any number of tags in any order is
// handled: text outside a block goes to cleanContent, text inside goes to
// reasoningContent, and the tags themselves are dropped.
//
// Starting state: when the chunk contains no tag, inThinkTag decides whether
// the whole chunk is reasoning or visible text. When the chunk does contain a
// tag, the first tag decides how the text before it is classified: text before
// a leading </think> is reasoning, text before a leading <think> is visible.
//
// Limitation: a tag that is itself split across two chunks (for example "<thi"
// followed by "nk>") is NOT recognised; each chunk is examined on its own and
// the signature carries no buffer for a partial tag.
//
// Parameters:
//   - content: the delta text of the current chunk.
//   - inThinkTag: whether the previous chunk ended inside a <think> block.
//   - reasoningBuilder: accumulates reasoning across chunks; this chunk's
//     reasoning is appended to it. A nil builder is tolerated and skipped.
//
// Returns:
//   - cleanContent: the visible text of this chunk. Whitespace is kept as-is so
//     consecutive deltas still concatenate correctly; the only adjustment is at
//     the seam of a block that both opened and closed inside this chunk, where
//     the whitespace following the block is dropped if the kept text is empty
//     or already ends in whitespace.
//   - reasoningContent: the reasoning text found in this chunk.
//   - stillInThinkTag: whether the chunk ends inside a <think> block.
func processStreamThinkTag(content string, inThinkTag bool, reasoningBuilder *strings.Builder) (cleanContent string, reasoningContent string, stillInThinkTag bool) {
	const (
		openTag  = "<think>"
		closeTag = "</think>"
		blank    = " \t\n\r\f"
	)

	if content == "" {
		return content, "", inThinkTag
	}

	// Let the first tag in the chunk (if any) decide the starting state.
	inside := inThinkTag
	openAt := strings.Index(content, openTag)
	closeAt := strings.Index(content, closeTag)
	if openAt >= 0 || closeAt >= 0 {
		inside = openAt < 0 || (closeAt >= 0 && closeAt < openAt)
	}

	var clean, reasoning strings.Builder
	openedHere := false // a <think> tag was seen in this chunk
	rejoin := false     // next visible text follows a block removed within this chunk
	rest := content
	for rest != "" {
		if inside {
			text, tail, found := strings.Cut(rest, closeTag)
			reasoning.WriteString(text)
			if !found {
				break
			}
			inside = false
			// Only tidy the seam for blocks fully contained in this chunk; a
			// block opened in an earlier chunk leaves the spacing untouched.
			rejoin = openedHere
			rest = tail
			continue
		}

		text, tail, found := strings.Cut(rest, openTag)
		if rejoin {
			sofar := clean.String()
			if sofar == "" || strings.TrimRight(sofar, blank) != sofar {
				text = strings.TrimLeft(text, blank)
			}
			rejoin = false
		}
		clean.WriteString(text)
		if !found {
			break
		}
		inside = true
		openedHere = true
		rest = tail
	}

	reasoningContent = reasoning.String()
	if reasoningBuilder != nil {
		reasoningBuilder.WriteString(reasoningContent)
	}
	return clean.String(), reasoningContent, inside
}

View File

@ -0,0 +1,147 @@
package openai
import (
"strings"
"testing"
)
// TestExtractThinkContent runs extractThinkContent over a table of complete
// messages and compares both the cleaned text and the extracted reasoning with
// the expected values.
func TestExtractThinkContent(t *testing.T) {
	type scenario struct {
		name      string
		input     string
		clean     string
		reasoning string
	}

	scenarios := []scenario{
		{
			name:  "No think tag",
			input: "Hello, world!",
			clean: "Hello, world!",
		},
		{
			name: "Empty input",
		},
		{
			name:      "Single think tag",
			input:     "Hello, <think>This is reasoning</think> world!",
			clean:     "Hello, world!",
			reasoning: "This is reasoning\n",
		},
		{
			name:      "Multiple think tags",
			input:     "<think>First reasoning</think>Hello, <think>Second reasoning</think> world!",
			clean:     "Hello, world!",
			reasoning: "First reasoning\nSecond reasoning\n",
		},
		{
			name:      "Think tag with newlines",
			input:     "Hello, <think>This is\nmulti-line\nreasoning</think> world!",
			clean:     "Hello, world!",
			reasoning: "This is\nmulti-line\nreasoning\n",
		},
		{
			name:      "Only think tag",
			input:     "<think>Only reasoning</think>",
			reasoning: "Only reasoning\n",
		},
		{
			// An opening tag without its closing tag is left in place.
			name:  "Incomplete think tag (should be ignored)",
			input: "Hello, <think>Incomplete reasoning world!",
			clean: "Hello, <think>Incomplete reasoning world!",
		},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.name, func(t *testing.T) {
			clean, reasoning := extractThinkContent(sc.input)
			if clean != sc.clean {
				t.Errorf("extractThinkContent() gotCleanContent = %q, want %q", clean, sc.clean)
			}
			if reasoning != sc.reasoning {
				t.Errorf("extractThinkContent() gotReasoningContent = %q, want %q", reasoning, sc.reasoning)
			}
		})
	}
}
// TestProcessStreamThinkTag feeds single chunks to processStreamThinkTag with a
// given starting state and verifies the visible text, the reasoning text and
// the resulting in-tag state. Each case gets a fresh reasoning builder.
func TestProcessStreamThinkTag(t *testing.T) {
	type scenario struct {
		name      string
		chunk     string
		startIn   bool
		clean     string
		reasoning string
		endIn     bool
	}

	scenarios := []scenario{
		{
			name: "Empty content",
		},
		{
			name:  "Regular content",
			chunk: "Hello, world!",
			clean: "Hello, world!",
		},
		{
			name:      "Content with <think> tag",
			chunk:     "Hello, <think>reasoning",
			clean:     "Hello, ",
			reasoning: "reasoning",
			endIn:     true,
		},
		{
			name:      "Content with </think> tag",
			chunk:     "reasoning</think> world!",
			startIn:   true,
			clean:     " world!",
			reasoning: "reasoning",
		},
		{
			name:      "Content inside <think> tag",
			chunk:     "reasoning content",
			startIn:   true,
			reasoning: "reasoning content",
			endIn:     true,
		},
		{
			name:      "Content with both <think> and </think> tags",
			chunk:     "Hello, <think>reasoning</think> world!",
			clean:     "Hello, world!",
			reasoning: "reasoning",
		},
	}

	for i := range scenarios {
		sc := scenarios[i]
		t.Run(sc.name, func(t *testing.T) {
			acc := new(strings.Builder)
			clean, reasoning, endIn := processStreamThinkTag(sc.chunk, sc.startIn, acc)
			if clean != sc.clean {
				t.Errorf("processStreamThinkTag() gotCleanContent = %q, want %q", clean, sc.clean)
			}
			if reasoning != sc.reasoning {
				t.Errorf("processStreamThinkTag() gotReasoningContent = %q, want %q", reasoning, sc.reasoning)
			}
			if endIn != sc.endIn {
				t.Errorf("processStreamThinkTag() gotStillInThinkTag = %v, want %v", endIn, sc.endIn)
			}
		})
	}
}