mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 01:06:37 +08:00
Add support for processing <think> tags in both streaming and non-streaming responses
This commit is contained in:
parent
8df4a2670b
commit
768a1be11a
@ -32,6 +32,10 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
||||
|
||||
common.SetEventStreamHeaders(c)
|
||||
|
||||
// Variables to track <think> tag state across chunks
|
||||
inThinkTag := false
|
||||
var reasoningBuilder strings.Builder
|
||||
|
||||
doneRendered := false
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
@ -52,14 +56,50 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
||||
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||
render.StringData(c, data) // if error happened, pass the data to client
|
||||
continue // just ignore the error
|
||||
render.StringData(c, data)
|
||||
continue
|
||||
}
|
||||
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
|
||||
// but for empty choice and no usage, we should not pass it to client, this is for azure
|
||||
continue // just ignore empty choice
|
||||
continue
|
||||
}
|
||||
render.StringData(c, data)
|
||||
|
||||
// Process <think> tags before rendering
|
||||
for i := range streamResponse.Choices {
|
||||
if streamResponse.Choices[i].Delta.Content != nil {
|
||||
content := conv.AsString(streamResponse.Choices[i].Delta.Content)
|
||||
logger.Debugf(c.Request.Context(), "Original content: %s", content)
|
||||
|
||||
// Process the content for <think> tags
|
||||
cleanContent, reasoningContent, newInThinkTag := processStreamThinkTag(content, inThinkTag, &reasoningBuilder)
|
||||
inThinkTag = newInThinkTag
|
||||
|
||||
// Update content
|
||||
streamResponse.Choices[i].Delta.Content = cleanContent
|
||||
|
||||
// If there's reasoning content, add it to reasoning_content
|
||||
if reasoningContent != "" {
|
||||
var reasoningContentAny any = reasoningContent
|
||||
streamResponse.Choices[i].Delta.ReasoningContent = reasoningContentAny
|
||||
logger.Debugf(c.Request.Context(), "Setting reasoning_content: %s", reasoningContent)
|
||||
}
|
||||
|
||||
logger.Debugf(c.Request.Context(), "Processed content: clean=%s, reasoning=%s, inThinkTag=%v",
|
||||
cleanContent, reasoningContent, inThinkTag)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-marshal the modified response
|
||||
modifiedData, err := json.Marshal(streamResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling modified stream response: " + err.Error())
|
||||
render.StringData(c, data) // if error happened, pass the original data to client
|
||||
} else {
|
||||
modifiedDataStr := dataPrefix + string(modifiedData)
|
||||
logger.Debugf(c.Request.Context(), "Modified response: %s", modifiedDataStr)
|
||||
render.StringData(c, modifiedDataStr)
|
||||
}
|
||||
|
||||
// Update responseText with cleaned content
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseText += conv.AsString(choice.Delta.Content)
|
||||
}
|
||||
@ -116,8 +156,46 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
// Process <think> tags in the response
|
||||
modified := false
|
||||
for i := range textResponse.Choices {
|
||||
if textResponse.Choices[i].Message.Content != nil {
|
||||
content := textResponse.Choices[i].Message.StringContent()
|
||||
cleanContent, reasoningContent := extractThinkContent(content)
|
||||
|
||||
// If content was modified, update it
|
||||
if content != cleanContent || reasoningContent != "" {
|
||||
textResponse.Choices[i].Message.Content = cleanContent
|
||||
|
||||
// If there's reasoning content, add it to reasoning_content
|
||||
if reasoningContent != "" {
|
||||
// Make sure ReasoningContent is set as a string, not any other type
|
||||
var reasoningContentAny any = reasoningContent
|
||||
textResponse.Choices[i].Message.ReasoningContent = reasoningContentAny
|
||||
}
|
||||
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the response was modified, re-marshal it
|
||||
if modified {
|
||||
modifiedResponseBody, err := json.Marshal(textResponse)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling modified response: " + err.Error())
|
||||
// If there's an error, use the original response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
} else {
|
||||
// Use the modified response body
|
||||
responseBody = modifiedResponseBody
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
}
|
||||
} else {
|
||||
// Reset response body with original content
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
}
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
|
@ -3,6 +3,8 @@ package openai
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
@ -21,3 +23,129 @@ func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatus
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// extractThinkContent extracts the content inside <think> tags and returns the cleaned content and reasoning content.
|
||||
// The cleaned content is the original content with all <think> tags and their content removed.
|
||||
// The reasoning content is the concatenation of all content inside <think> tags.
|
||||
func extractThinkContent(content string) (cleanContent string, reasoningContent string) {
|
||||
// If content is nil or empty, return as is
|
||||
if content == "" {
|
||||
return content, ""
|
||||
}
|
||||
|
||||
// Use regular expression to match <think>...</think> tags
|
||||
re := regexp.MustCompile(`<think>([\s\S]*?)</think>`)
|
||||
matches := re.FindAllStringSubmatch(content, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
// No <think> tags found, return original content
|
||||
return content, ""
|
||||
}
|
||||
|
||||
// Extract all content inside <think> tags
|
||||
var reasoningBuilder strings.Builder
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
reasoningBuilder.WriteString(match[1])
|
||||
reasoningBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Remove all <think> tags and their content from the original content
|
||||
cleanContent = re.ReplaceAllString(content, "")
|
||||
|
||||
// Fix multiple spaces that might have been created
|
||||
spaceRe := regexp.MustCompile(`\s+`)
|
||||
cleanContent = spaceRe.ReplaceAllString(cleanContent, " ")
|
||||
|
||||
// Remove any extra whitespace that might have been created
|
||||
cleanContent = strings.TrimSpace(cleanContent)
|
||||
|
||||
return cleanContent, reasoningBuilder.String()
|
||||
}
|
||||
|
||||
// processStreamThinkTag processes a chunk of content for <think> tags in streaming mode.
|
||||
// It handles partial tags that may be split across chunks.
|
||||
// Returns:
|
||||
// - cleanContent: the content with <think> tags removed
|
||||
// - reasoningContent: the content inside <think> tags
|
||||
// - inThinkTag: whether we're currently inside a <think> tag
|
||||
func processStreamThinkTag(content string, inThinkTag bool, reasoningBuilder *strings.Builder) (cleanContent string, reasoningContent string, stillInThinkTag bool) {
|
||||
if content == "" {
|
||||
return content, "", inThinkTag
|
||||
}
|
||||
|
||||
// Initialize reasoningContent as empty
|
||||
reasoningContent = ""
|
||||
|
||||
// Handle case where content contains both <think> and </think> tags
|
||||
if strings.Contains(content, "<think>") && strings.Contains(content, "</think>") {
|
||||
// Extract content before <think> tag
|
||||
beforeThink := strings.Split(content, "<think>")[0]
|
||||
|
||||
// Extract content between <think> and </think> tags
|
||||
betweenTags := strings.Split(strings.Split(content, "<think>")[1], "</think>")[0]
|
||||
reasoningBuilder.WriteString(betweenTags)
|
||||
reasoningContent = betweenTags
|
||||
|
||||
// Extract content after </think> tag
|
||||
afterThink := strings.Split(content, "</think>")[1]
|
||||
|
||||
// Combine content before and after tags
|
||||
cleanContent = beforeThink + afterThink
|
||||
|
||||
// Fix multiple spaces that might have been created
|
||||
spaceRe := regexp.MustCompile(`\s+`)
|
||||
cleanContent = spaceRe.ReplaceAllString(cleanContent, " ")
|
||||
|
||||
// Remove any extra whitespace that might have been created
|
||||
cleanContent = strings.TrimSpace(cleanContent)
|
||||
|
||||
stillInThinkTag = false
|
||||
|
||||
return cleanContent, reasoningContent, stillInThinkTag
|
||||
}
|
||||
|
||||
// Handle other cases
|
||||
switch {
|
||||
case strings.Contains(content, "<think>"):
|
||||
stillInThinkTag = true
|
||||
parts := strings.Split(content, "<think>")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
cleanContent = parts[0]
|
||||
} else {
|
||||
cleanContent = ""
|
||||
}
|
||||
|
||||
if len(parts) > 1 {
|
||||
reasoningBuilder.WriteString(parts[1])
|
||||
reasoningContent = parts[1]
|
||||
}
|
||||
|
||||
case strings.Contains(content, "</think>"):
|
||||
stillInThinkTag = false
|
||||
parts := strings.Split(content, "</think>")
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
cleanContent = parts[1]
|
||||
} else {
|
||||
cleanContent = ""
|
||||
}
|
||||
|
||||
if len(parts) > 0 {
|
||||
reasoningBuilder.WriteString(parts[0])
|
||||
reasoningContent = parts[0]
|
||||
}
|
||||
|
||||
case inThinkTag:
|
||||
reasoningBuilder.WriteString(content)
|
||||
reasoningContent = content
|
||||
cleanContent = ""
|
||||
stillInThinkTag = true
|
||||
|
||||
default:
|
||||
cleanContent = content
|
||||
stillInThinkTag = false
|
||||
}
|
||||
|
||||
return cleanContent, reasoningContent, stillInThinkTag
|
||||
}
|
||||
|
147
relay/adaptor/openai/util_test.go
Normal file
147
relay/adaptor/openai/util_test.go
Normal file
@ -0,0 +1,147 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractThinkContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantCleanContent string
|
||||
wantReasoningContent string
|
||||
}{
|
||||
{
|
||||
name: "No think tag",
|
||||
input: "Hello, world!",
|
||||
wantCleanContent: "Hello, world!",
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "Empty input",
|
||||
input: "",
|
||||
wantCleanContent: "",
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "Single think tag",
|
||||
input: "Hello, <think>This is reasoning</think> world!",
|
||||
wantCleanContent: "Hello, world!",
|
||||
wantReasoningContent: "This is reasoning\n",
|
||||
},
|
||||
{
|
||||
name: "Multiple think tags",
|
||||
input: "<think>First reasoning</think>Hello, <think>Second reasoning</think> world!",
|
||||
wantCleanContent: "Hello, world!",
|
||||
wantReasoningContent: "First reasoning\nSecond reasoning\n",
|
||||
},
|
||||
{
|
||||
name: "Think tag with newlines",
|
||||
input: "Hello, <think>This is\nmulti-line\nreasoning</think> world!",
|
||||
wantCleanContent: "Hello, world!",
|
||||
wantReasoningContent: "This is\nmulti-line\nreasoning\n",
|
||||
},
|
||||
{
|
||||
name: "Only think tag",
|
||||
input: "<think>Only reasoning</think>",
|
||||
wantCleanContent: "",
|
||||
wantReasoningContent: "Only reasoning\n",
|
||||
},
|
||||
{
|
||||
name: "Incomplete think tag (should be ignored)",
|
||||
input: "Hello, <think>Incomplete reasoning world!",
|
||||
wantCleanContent: "Hello, <think>Incomplete reasoning world!",
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotCleanContent, gotReasoningContent := extractThinkContent(tt.input)
|
||||
if gotCleanContent != tt.wantCleanContent {
|
||||
t.Errorf("extractThinkContent() gotCleanContent = %q, want %q", gotCleanContent, tt.wantCleanContent)
|
||||
}
|
||||
if gotReasoningContent != tt.wantReasoningContent {
|
||||
t.Errorf("extractThinkContent() gotReasoningContent = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessStreamThinkTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
initialInThinkTag bool
|
||||
wantCleanContent string
|
||||
wantReasoningContent string
|
||||
wantStillInThinkTag bool
|
||||
}{
|
||||
{
|
||||
name: "Empty content",
|
||||
content: "",
|
||||
initialInThinkTag: false,
|
||||
wantCleanContent: "",
|
||||
wantReasoningContent: "",
|
||||
wantStillInThinkTag: false,
|
||||
},
|
||||
{
|
||||
name: "Regular content",
|
||||
content: "Hello, world!",
|
||||
initialInThinkTag: false,
|
||||
wantCleanContent: "Hello, world!",
|
||||
wantReasoningContent: "",
|
||||
wantStillInThinkTag: false,
|
||||
},
|
||||
{
|
||||
name: "Content with <think> tag",
|
||||
content: "Hello, <think>reasoning",
|
||||
initialInThinkTag: false,
|
||||
wantCleanContent: "Hello, ",
|
||||
wantReasoningContent: "reasoning",
|
||||
wantStillInThinkTag: true,
|
||||
},
|
||||
{
|
||||
name: "Content with </think> tag",
|
||||
content: "reasoning</think> world!",
|
||||
initialInThinkTag: true,
|
||||
wantCleanContent: " world!",
|
||||
wantReasoningContent: "reasoning",
|
||||
wantStillInThinkTag: false,
|
||||
},
|
||||
{
|
||||
name: "Content inside <think> tag",
|
||||
content: "reasoning content",
|
||||
initialInThinkTag: true,
|
||||
wantCleanContent: "",
|
||||
wantReasoningContent: "reasoning content",
|
||||
wantStillInThinkTag: true,
|
||||
},
|
||||
{
|
||||
name: "Content with both <think> and </think> tags",
|
||||
content: "Hello, <think>reasoning</think> world!",
|
||||
initialInThinkTag: false,
|
||||
wantCleanContent: "Hello, world!",
|
||||
wantReasoningContent: "reasoning",
|
||||
wantStillInThinkTag: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var reasoningBuilder strings.Builder
|
||||
gotCleanContent, gotReasoningContent, gotStillInThinkTag := processStreamThinkTag(tt.content, tt.initialInThinkTag, &reasoningBuilder)
|
||||
|
||||
if gotCleanContent != tt.wantCleanContent {
|
||||
t.Errorf("processStreamThinkTag() gotCleanContent = %q, want %q", gotCleanContent, tt.wantCleanContent)
|
||||
}
|
||||
if gotReasoningContent != tt.wantReasoningContent {
|
||||
t.Errorf("processStreamThinkTag() gotReasoningContent = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||
}
|
||||
if gotStillInThinkTag != tt.wantStillInThinkTag {
|
||||
t.Errorf("processStreamThinkTag() gotStillInThinkTag = %v, want %v", gotStillInThinkTag, tt.wantStillInThinkTag)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user