mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 09:16:36 +08:00
Add support for processing <think> tags in both streaming and non-streaming responses
This commit is contained in:
parent
8df4a2670b
commit
768a1be11a
@ -32,6 +32,10 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
|||||||
|
|
||||||
common.SetEventStreamHeaders(c)
|
common.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
// Variables to track <think> tag state across chunks
|
||||||
|
inThinkTag := false
|
||||||
|
var reasoningBuilder strings.Builder
|
||||||
|
|
||||||
doneRendered := false
|
doneRendered := false
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
@ -52,14 +56,50 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
|
|||||||
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
|
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error unmarshalling stream response: " + err.Error())
|
logger.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
render.StringData(c, data) // if error happened, pass the data to client
|
render.StringData(c, data)
|
||||||
continue // just ignore the error
|
continue
|
||||||
}
|
}
|
||||||
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
|
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
|
||||||
// but for empty choice and no usage, we should not pass it to client, this is for azure
|
continue
|
||||||
continue // just ignore empty choice
|
|
||||||
}
|
}
|
||||||
render.StringData(c, data)
|
|
||||||
|
// Process <think> tags before rendering
|
||||||
|
for i := range streamResponse.Choices {
|
||||||
|
if streamResponse.Choices[i].Delta.Content != nil {
|
||||||
|
content := conv.AsString(streamResponse.Choices[i].Delta.Content)
|
||||||
|
logger.Debugf(c.Request.Context(), "Original content: %s", content)
|
||||||
|
|
||||||
|
// Process the content for <think> tags
|
||||||
|
cleanContent, reasoningContent, newInThinkTag := processStreamThinkTag(content, inThinkTag, &reasoningBuilder)
|
||||||
|
inThinkTag = newInThinkTag
|
||||||
|
|
||||||
|
// Update content
|
||||||
|
streamResponse.Choices[i].Delta.Content = cleanContent
|
||||||
|
|
||||||
|
// If there's reasoning content, add it to reasoning_content
|
||||||
|
if reasoningContent != "" {
|
||||||
|
var reasoningContentAny any = reasoningContent
|
||||||
|
streamResponse.Choices[i].Delta.ReasoningContent = reasoningContentAny
|
||||||
|
logger.Debugf(c.Request.Context(), "Setting reasoning_content: %s", reasoningContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf(c.Request.Context(), "Processed content: clean=%s, reasoning=%s, inThinkTag=%v",
|
||||||
|
cleanContent, reasoningContent, inThinkTag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-marshal the modified response
|
||||||
|
modifiedData, err := json.Marshal(streamResponse)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("error marshalling modified stream response: " + err.Error())
|
||||||
|
render.StringData(c, data) // if error happened, pass the original data to client
|
||||||
|
} else {
|
||||||
|
modifiedDataStr := dataPrefix + string(modifiedData)
|
||||||
|
logger.Debugf(c.Request.Context(), "Modified response: %s", modifiedDataStr)
|
||||||
|
render.StringData(c, modifiedDataStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update responseText with cleaned content
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseText += conv.AsString(choice.Delta.Content)
|
responseText += conv.AsString(choice.Delta.Content)
|
||||||
}
|
}
|
||||||
@ -116,8 +156,46 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
|
|||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
// Reset response body
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
// Process <think> tags in the response
|
||||||
|
modified := false
|
||||||
|
for i := range textResponse.Choices {
|
||||||
|
if textResponse.Choices[i].Message.Content != nil {
|
||||||
|
content := textResponse.Choices[i].Message.StringContent()
|
||||||
|
cleanContent, reasoningContent := extractThinkContent(content)
|
||||||
|
|
||||||
|
// If content was modified, update it
|
||||||
|
if content != cleanContent || reasoningContent != "" {
|
||||||
|
textResponse.Choices[i].Message.Content = cleanContent
|
||||||
|
|
||||||
|
// If there's reasoning content, add it to reasoning_content
|
||||||
|
if reasoningContent != "" {
|
||||||
|
// Make sure ReasoningContent is set as a string, not any other type
|
||||||
|
var reasoningContentAny any = reasoningContent
|
||||||
|
textResponse.Choices[i].Message.ReasoningContent = reasoningContentAny
|
||||||
|
}
|
||||||
|
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the response was modified, re-marshal it
|
||||||
|
if modified {
|
||||||
|
modifiedResponseBody, err := json.Marshal(textResponse)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("error marshalling modified response: " + err.Error())
|
||||||
|
// If there's an error, use the original response body
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
} else {
|
||||||
|
// Use the modified response body
|
||||||
|
responseBody = modifiedResponseBody
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Reset response body with original content
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||||
|
@ -3,6 +3,8 @@ package openai
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
@ -21,3 +23,129 @@ func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatus
|
|||||||
StatusCode: statusCode,
|
StatusCode: statusCode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// thinkBlockPattern matches one complete <think>...</think> block and captures
// its body. It is compiled once at package scope rather than on every call.
var thinkBlockPattern = regexp.MustCompile(`<think>([\s\S]*?)</think>`)

// extractThinkContent separates reasoning wrapped in <think>...</think> blocks
// from the visible answer of a complete (non-streaming) message.
//
// Returns:
//   - cleanContent: content with every complete <think> block removed and
//     leading/trailing whitespace trimmed. Line breaks and indentation of the
//     remaining text are preserved; only a duplicated space/tab left at the
//     spot where a block was cut out is merged.
//   - reasoningContent: the bodies of all blocks in order of appearance, each
//     followed by "\n". Empty when no complete block exists.
//
// Content without a complete block (including an unterminated <think>) is
// returned unchanged.
func extractThinkContent(content string) (cleanContent string, reasoningContent string) {
	if content == "" {
		return content, ""
	}

	// Each span is [matchStart, matchEnd, bodyStart, bodyEnd].
	spans := thinkBlockPattern.FindAllStringSubmatchIndex(content, -1)
	if len(spans) == 0 {
		return content, ""
	}

	const blanks = " \t"

	var clean, reasoning strings.Builder
	clean.Grow(len(content))

	// keep appends visible text. Removing a block from "a <think>..</think> b"
	// would leave two blanks in a row, so horizontal whitespace is merged at
	// the seam. Newlines are never touched: they carry formatting.
	keep := func(segment string) {
		if prev := clean.String(); prev != "" && strings.IndexByte(blanks, prev[len(prev)-1]) >= 0 {
			segment = strings.TrimLeft(segment, blanks)
		}
		clean.WriteString(segment)
	}

	next := 0
	for _, span := range spans {
		keep(content[next:span[0]])
		reasoning.WriteString(content[span[2]:span[3]])
		reasoning.WriteString("\n")
		next = span[1]
	}
	keep(content[next:])

	return strings.TrimSpace(clean.String()), reasoning.String()
}
|
||||||
|
|
||||||
|
// processStreamThinkTag processes a chunk of content for <think> tags in streaming mode.
|
||||||
|
// It handles partial tags that may be split across chunks.
|
||||||
|
// Returns:
|
||||||
|
// - cleanContent: the content with <think> tags removed
|
||||||
|
// - reasoningContent: the content inside <think> tags
|
||||||
|
// - inThinkTag: whether we're currently inside a <think> tag
|
||||||
|
func processStreamThinkTag(content string, inThinkTag bool, reasoningBuilder *strings.Builder) (cleanContent string, reasoningContent string, stillInThinkTag bool) {
|
||||||
|
if content == "" {
|
||||||
|
return content, "", inThinkTag
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize reasoningContent as empty
|
||||||
|
reasoningContent = ""
|
||||||
|
|
||||||
|
// Handle case where content contains both <think> and </think> tags
|
||||||
|
if strings.Contains(content, "<think>") && strings.Contains(content, "</think>") {
|
||||||
|
// Extract content before <think> tag
|
||||||
|
beforeThink := strings.Split(content, "<think>")[0]
|
||||||
|
|
||||||
|
// Extract content between <think> and </think> tags
|
||||||
|
betweenTags := strings.Split(strings.Split(content, "<think>")[1], "</think>")[0]
|
||||||
|
reasoningBuilder.WriteString(betweenTags)
|
||||||
|
reasoningContent = betweenTags
|
||||||
|
|
||||||
|
// Extract content after </think> tag
|
||||||
|
afterThink := strings.Split(content, "</think>")[1]
|
||||||
|
|
||||||
|
// Combine content before and after tags
|
||||||
|
cleanContent = beforeThink + afterThink
|
||||||
|
|
||||||
|
// Fix multiple spaces that might have been created
|
||||||
|
spaceRe := regexp.MustCompile(`\s+`)
|
||||||
|
cleanContent = spaceRe.ReplaceAllString(cleanContent, " ")
|
||||||
|
|
||||||
|
// Remove any extra whitespace that might have been created
|
||||||
|
cleanContent = strings.TrimSpace(cleanContent)
|
||||||
|
|
||||||
|
stillInThinkTag = false
|
||||||
|
|
||||||
|
return cleanContent, reasoningContent, stillInThinkTag
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle other cases
|
||||||
|
switch {
|
||||||
|
case strings.Contains(content, "<think>"):
|
||||||
|
stillInThinkTag = true
|
||||||
|
parts := strings.Split(content, "<think>")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
cleanContent = parts[0]
|
||||||
|
} else {
|
||||||
|
cleanContent = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) > 1 {
|
||||||
|
reasoningBuilder.WriteString(parts[1])
|
||||||
|
reasoningContent = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
case strings.Contains(content, "</think>"):
|
||||||
|
stillInThinkTag = false
|
||||||
|
parts := strings.Split(content, "</think>")
|
||||||
|
if len(parts) > 1 && parts[1] != "" {
|
||||||
|
cleanContent = parts[1]
|
||||||
|
} else {
|
||||||
|
cleanContent = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) > 0 {
|
||||||
|
reasoningBuilder.WriteString(parts[0])
|
||||||
|
reasoningContent = parts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
case inThinkTag:
|
||||||
|
reasoningBuilder.WriteString(content)
|
||||||
|
reasoningContent = content
|
||||||
|
cleanContent = ""
|
||||||
|
stillInThinkTag = true
|
||||||
|
|
||||||
|
default:
|
||||||
|
cleanContent = content
|
||||||
|
stillInThinkTag = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleanContent, reasoningContent, stillInThinkTag
|
||||||
|
}
|
||||||
|
147
relay/adaptor/openai/util_test.go
Normal file
147
relay/adaptor/openai/util_test.go
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractThinkContent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantCleanContent string
|
||||||
|
wantReasoningContent string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No think tag",
|
||||||
|
input: "Hello, world!",
|
||||||
|
wantCleanContent: "Hello, world!",
|
||||||
|
wantReasoningContent: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty input",
|
||||||
|
input: "",
|
||||||
|
wantCleanContent: "",
|
||||||
|
wantReasoningContent: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single think tag",
|
||||||
|
input: "Hello, <think>This is reasoning</think> world!",
|
||||||
|
wantCleanContent: "Hello, world!",
|
||||||
|
wantReasoningContent: "This is reasoning\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple think tags",
|
||||||
|
input: "<think>First reasoning</think>Hello, <think>Second reasoning</think> world!",
|
||||||
|
wantCleanContent: "Hello, world!",
|
||||||
|
wantReasoningContent: "First reasoning\nSecond reasoning\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Think tag with newlines",
|
||||||
|
input: "Hello, <think>This is\nmulti-line\nreasoning</think> world!",
|
||||||
|
wantCleanContent: "Hello, world!",
|
||||||
|
wantReasoningContent: "This is\nmulti-line\nreasoning\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Only think tag",
|
||||||
|
input: "<think>Only reasoning</think>",
|
||||||
|
wantCleanContent: "",
|
||||||
|
wantReasoningContent: "Only reasoning\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Incomplete think tag (should be ignored)",
|
||||||
|
input: "Hello, <think>Incomplete reasoning world!",
|
||||||
|
wantCleanContent: "Hello, <think>Incomplete reasoning world!",
|
||||||
|
wantReasoningContent: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotCleanContent, gotReasoningContent := extractThinkContent(tt.input)
|
||||||
|
if gotCleanContent != tt.wantCleanContent {
|
||||||
|
t.Errorf("extractThinkContent() gotCleanContent = %q, want %q", gotCleanContent, tt.wantCleanContent)
|
||||||
|
}
|
||||||
|
if gotReasoningContent != tt.wantReasoningContent {
|
||||||
|
t.Errorf("extractThinkContent() gotReasoningContent = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessStreamThinkTag(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
initialInThinkTag bool
|
||||||
|
wantCleanContent string
|
||||||
|
wantReasoningContent string
|
||||||
|
wantStillInThinkTag bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty content",
|
||||||
|
content: "",
|
||||||
|
initialInThinkTag: false,
|
||||||
|
wantCleanContent: "",
|
||||||
|
wantReasoningContent: "",
|
||||||
|
wantStillInThinkTag: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Regular content",
|
||||||
|
content: "Hello, world!",
|
||||||
|
initialInThinkTag: false,
|
||||||
|
wantCleanContent: "Hello, world!",
|
||||||
|
wantReasoningContent: "",
|
||||||
|
wantStillInThinkTag: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Content with <think> tag",
|
||||||
|
content: "Hello, <think>reasoning",
|
||||||
|
initialInThinkTag: false,
|
||||||
|
wantCleanContent: "Hello, ",
|
||||||
|
wantReasoningContent: "reasoning",
|
||||||
|
wantStillInThinkTag: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Content with </think> tag",
|
||||||
|
content: "reasoning</think> world!",
|
||||||
|
initialInThinkTag: true,
|
||||||
|
wantCleanContent: " world!",
|
||||||
|
wantReasoningContent: "reasoning",
|
||||||
|
wantStillInThinkTag: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Content inside <think> tag",
|
||||||
|
content: "reasoning content",
|
||||||
|
initialInThinkTag: true,
|
||||||
|
wantCleanContent: "",
|
||||||
|
wantReasoningContent: "reasoning content",
|
||||||
|
wantStillInThinkTag: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Content with both <think> and </think> tags",
|
||||||
|
content: "Hello, <think>reasoning</think> world!",
|
||||||
|
initialInThinkTag: false,
|
||||||
|
wantCleanContent: "Hello, world!",
|
||||||
|
wantReasoningContent: "reasoning",
|
||||||
|
wantStillInThinkTag: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var reasoningBuilder strings.Builder
|
||||||
|
gotCleanContent, gotReasoningContent, gotStillInThinkTag := processStreamThinkTag(tt.content, tt.initialInThinkTag, &reasoningBuilder)
|
||||||
|
|
||||||
|
if gotCleanContent != tt.wantCleanContent {
|
||||||
|
t.Errorf("processStreamThinkTag() gotCleanContent = %q, want %q", gotCleanContent, tt.wantCleanContent)
|
||||||
|
}
|
||||||
|
if gotReasoningContent != tt.wantReasoningContent {
|
||||||
|
t.Errorf("processStreamThinkTag() gotReasoningContent = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||||
|
}
|
||||||
|
if gotStillInThinkTag != tt.wantStillInThinkTag {
|
||||||
|
t.Errorf("processStreamThinkTag() gotStillInThinkTag = %v, want %v", gotStillInThinkTag, tt.wantStillInThinkTag)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user