mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	refactor: abusing goroutines and channel (#1561)
* refactor: abusing goroutines * fix: trim data prefix * refactor: move functions to render package * refactor: add back trim & flush --------- Co-authored-by: JustSong <quanpengsong@gmail.com>
This commit is contained in:
		
							
								
								
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								common/render/render.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
package render
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func StringData(c *gin.Context, str string) {
 | 
			
		||||
	str = strings.TrimPrefix(str, "data: ")
 | 
			
		||||
	str = strings.TrimSuffix(str, "\r")
 | 
			
		||||
	c.Render(-1, common.CustomEvent{Data: "data: " + str})
 | 
			
		||||
	c.Writer.Flush()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ObjectData(c *gin.Context, object interface{}) error {
 | 
			
		||||
	jsonData, err := json.Marshal(object)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error marshalling object: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	StringData(c, string(jsonData))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Done(c *gin.Context) {
 | 
			
		||||
	StringData(c, "[DONE]")
 | 
			
		||||
}
 | 
			
		||||
@@ -4,6 +4,12 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
@@ -12,10 +18,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 | 
			
		||||
@@ -89,6 +91,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
 | 
			
		||||
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	var documents []LibraryDocument
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
@@ -102,60 +105,48 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	var documents []LibraryDocument
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var AIProxyLibraryResponse LibraryStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if len(AIProxyLibraryResponse.Documents) != 0 {
 | 
			
		||||
				documents = AIProxyLibraryResponse.Documents
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			response := documentsAIProxyLibrary(documents)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || data[:5] != "data:" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
		data = data[5:]
 | 
			
		||||
 | 
			
		||||
		var AIProxyLibraryResponse LibraryStreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if len(AIProxyLibraryResponse.Documents) != 0 {
 | 
			
		||||
			documents = AIProxyLibraryResponse.Documents
 | 
			
		||||
		}
 | 
			
		||||
		response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	response := documentsAIProxyLibrary(documents)
 | 
			
		||||
	err := render.ObjectData(c, response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,15 +3,17 @@ package ali
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
 | 
			
		||||
@@ -181,56 +183,43 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	//lastResponseText := ""
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var aliResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &aliResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if aliResponse.Usage.OutputTokens != 0 {
 | 
			
		||||
				usage.PromptTokens = aliResponse.Usage.InputTokens
 | 
			
		||||
				usage.CompletionTokens = aliResponse.Usage.OutputTokens
 | 
			
		||||
				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 | 
			
		||||
			//lastResponseText = aliResponse.Output.Text
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || data[:5] != "data:" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = data[5:]
 | 
			
		||||
 | 
			
		||||
		var aliResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &aliResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if aliResponse.Usage.OutputTokens != 0 {
 | 
			
		||||
			usage.PromptTokens = aliResponse.Usage.InputTokens
 | 
			
		||||
			usage.CompletionTokens = aliResponse.Usage.OutputTokens
 | 
			
		||||
			usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 | 
			
		||||
		}
 | 
			
		||||
		response := streamResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -169,64 +170,59 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !strings.HasPrefix(data, "data:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	var modelName string
 | 
			
		||||
	var id string
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSpace(data)
 | 
			
		||||
			var claudeResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &claudeResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
 | 
			
		||||
			if meta != nil {
 | 
			
		||||
				usage.PromptTokens += meta.Usage.InputTokens
 | 
			
		||||
				usage.CompletionTokens += meta.Usage.OutputTokens
 | 
			
		||||
				modelName = meta.Model
 | 
			
		||||
				id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response.Id = id
 | 
			
		||||
			response.Model = modelName
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 6 || !strings.HasPrefix(data, "data:") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
		data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
		data = strings.TrimSpace(data)
 | 
			
		||||
 | 
			
		||||
		var claudeResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &claudeResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
 | 
			
		||||
		if meta != nil {
 | 
			
		||||
			usage.PromptTokens += meta.Usage.InputTokens
 | 
			
		||||
			usage.CompletionTokens += meta.Usage.OutputTokens
 | 
			
		||||
			modelName = meta.Model
 | 
			
		||||
			id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response.Id = id
 | 
			
		||||
		response.Model = modelName
 | 
			
		||||
		response.Created = createdTime
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,13 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/client"
 | 
			
		||||
@@ -12,11 +19,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 | 
			
		||||
@@ -137,59 +139,41 @@ func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.Embeddin
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[6:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var baiduResponse ChatStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &baiduResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if baiduResponse.Usage.TotalTokens != 0 {
 | 
			
		||||
				usage.TotalTokens = baiduResponse.Usage.TotalTokens
 | 
			
		||||
				usage.PromptTokens = baiduResponse.Usage.PromptTokens
 | 
			
		||||
				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 6 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = data[6:]
 | 
			
		||||
 | 
			
		||||
		var baiduResponse ChatStreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &baiduResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if baiduResponse.Usage.TotalTokens != 0 {
 | 
			
		||||
			usage.TotalTokens = baiduResponse.Usage.TotalTokens
 | 
			
		||||
			usage.PromptTokens = baiduResponse.Usage.PromptTokens
 | 
			
		||||
			usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
 | 
			
		||||
		}
 | 
			
		||||
		response := streamResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
 
 | 
			
		||||
@@ -2,8 +2,8 @@ package cloudflare
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -17,21 +17,20 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
    var promptBuilder strings.Builder
 | 
			
		||||
    for _, message := range textRequest.Messages {
 | 
			
		||||
        promptBuilder.WriteString(message.StringContent())
 | 
			
		||||
        promptBuilder.WriteString("\n")  // 添加换行符来分隔每个消息
 | 
			
		||||
    }
 | 
			
		||||
	var promptBuilder strings.Builder
 | 
			
		||||
	for _, message := range textRequest.Messages {
 | 
			
		||||
		promptBuilder.WriteString(message.StringContent())
 | 
			
		||||
		promptBuilder.WriteString("\n") // 添加换行符来分隔每个消息
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
    return &Request{
 | 
			
		||||
        MaxTokens:   textRequest.MaxTokens,
 | 
			
		||||
        Prompt:      promptBuilder.String(),
 | 
			
		||||
        Stream:      textRequest.Stream,
 | 
			
		||||
        Temperature: textRequest.Temperature,
 | 
			
		||||
    }
 | 
			
		||||
	return &Request{
 | 
			
		||||
		MaxTokens:   textRequest.MaxTokens,
 | 
			
		||||
		Prompt:      promptBuilder.String(),
 | 
			
		||||
		Stream:      textRequest.Stream,
 | 
			
		||||
		Temperature: textRequest.Temperature,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse {
 | 
			
		||||
	choice := openai.TextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
@@ -63,67 +62,54 @@ func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai
 | 
			
		||||
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := bytes.IndexByte(data, '\n'); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < len("data: ") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	id := helper.GetResponseID(c)
 | 
			
		||||
	responseModel := c.GetString("original_model")
 | 
			
		||||
	var responseText string
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			var cloudflareResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &cloudflareResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			responseText += cloudflareResponse.Response
 | 
			
		||||
			response.Id = id
 | 
			
		||||
			response.Model = responseModel
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < len("data: ") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
		data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
		data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
 | 
			
		||||
		var cloudflareResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &cloudflareResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := StreamResponseCloudflare2OpenAI(&cloudflareResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		responseText += cloudflareResponse.Response
 | 
			
		||||
		response.Id = id
 | 
			
		||||
		response.Model = responseModel
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens)
 | 
			
		||||
	return nil, usage
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,9 +2,9 @@ package cohere
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -134,66 +134,53 @@ func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse {
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	createdTime := helper.GetTimestamp()
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := bytes.IndexByte(data, '\n'); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			var cohereResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &cohereResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
 | 
			
		||||
			if meta != nil {
 | 
			
		||||
				usage.PromptTokens += meta.Meta.Tokens.InputTokens
 | 
			
		||||
				usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
 | 
			
		||||
			response.Model = c.GetString("original_model")
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
 | 
			
		||||
		var cohereResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &cohereResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
 | 
			
		||||
		response, meta := StreamResponseCohere2OpenAI(&cohereResponse)
 | 
			
		||||
		if meta != nil {
 | 
			
		||||
			usage.PromptTokens += meta.Meta.Tokens.InputTokens
 | 
			
		||||
			usage.CompletionTokens += meta.Meta.Tokens.OutputTokens
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response.Id = fmt.Sprintf("chatcmpl-%d", createdTime)
 | 
			
		||||
		response.Model = c.GetString("original_model")
 | 
			
		||||
		response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,11 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/conv"
 | 
			
		||||
@@ -12,9 +17,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://www.coze.com/open
 | 
			
		||||
@@ -109,69 +111,54 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
	var responseText string
 | 
			
		||||
	createdTime := helper.GetTimestamp()
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !strings.HasPrefix(data, "data:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	var modelName string
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			var cozeResponse StreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &cozeResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			for _, choice := range response.Choices {
 | 
			
		||||
				responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
			}
 | 
			
		||||
			response.Model = modelName
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	_ = resp.Body.Close()
 | 
			
		||||
		data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
		data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
 | 
			
		||||
		var cozeResponse StreamResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &cozeResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response, _ := StreamResponseCoze2OpenAI(&cozeResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, choice := range response.Choices {
 | 
			
		||||
			responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
		}
 | 
			
		||||
		response.Model = modelName
 | 
			
		||||
		response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -275,64 +276,50 @@ func embeddingResponseGemini2OpenAI(response *EmbeddingResponse) *openai.Embeddi
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			data = strings.TrimSpace(data)
 | 
			
		||||
			if !strings.HasPrefix(data, "data: ") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
			data = strings.TrimSuffix(data, "\"")
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var geminiResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &geminiResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseGeminiChat2OpenAI(&geminiResponse)
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			responseText += response.Choices[0].Delta.StringContent()
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		data = strings.TrimSpace(data)
 | 
			
		||||
		if !strings.HasPrefix(data, "data: ") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = strings.TrimPrefix(data, "data: ")
 | 
			
		||||
		data = strings.TrimSuffix(data, "\"")
 | 
			
		||||
 | 
			
		||||
		var geminiResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &geminiResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		responseText += response.Choices[0].Delta.StringContent()
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,12 +5,14 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/image"
 | 
			
		||||
@@ -105,54 +107,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "}\n"); i >= 0 {
 | 
			
		||||
			return i + 2, data[0:i], nil
 | 
			
		||||
			return i + 2, data[0 : i+1], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := strings.TrimPrefix(scanner.Text(), "}")
 | 
			
		||||
			dataChan <- data + "}"
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var ollamaResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &ollamaResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if ollamaResponse.EvalCount != 0 {
 | 
			
		||||
				usage.PromptTokens = ollamaResponse.PromptEvalCount
 | 
			
		||||
				usage.CompletionTokens = ollamaResponse.EvalCount
 | 
			
		||||
				usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseOllama2OpenAI(&ollamaResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := strings.TrimPrefix(scanner.Text(), "}")
 | 
			
		||||
		data = data + "}"
 | 
			
		||||
 | 
			
		||||
		var ollamaResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &ollamaResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
		if ollamaResponse.EvalCount != 0 {
 | 
			
		||||
			usage.PromptTokens = ollamaResponse.PromptEvalCount
 | 
			
		||||
			usage.CompletionTokens = ollamaResponse.EvalCount
 | 
			
		||||
			usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := streamResponseOllama2OpenAI(&ollamaResponse)
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -25,88 +26,68 @@ const (
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
	var usage *model.Usage
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < dataPrefixLength { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(data[dataPrefixLength:], done) {
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			switch relayMode {
 | 
			
		||||
			case relaymode.ChatCompletions:
 | 
			
		||||
				var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
					dataChan <- data // if error happened, pass the data to client
 | 
			
		||||
					continue         // just ignore the error
 | 
			
		||||
				}
 | 
			
		||||
				if len(streamResponse.Choices) == 0 {
 | 
			
		||||
					// but for empty choice, we should not pass it to client, this is for azure
 | 
			
		||||
					continue // just ignore empty choice
 | 
			
		||||
				}
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				for _, choice := range streamResponse.Choices {
 | 
			
		||||
					responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
				}
 | 
			
		||||
				if streamResponse.Usage != nil {
 | 
			
		||||
					usage = streamResponse.Usage
 | 
			
		||||
				}
 | 
			
		||||
			case relaymode.Completions:
 | 
			
		||||
				dataChan <- data
 | 
			
		||||
				var streamResponse CompletionsStreamResponse
 | 
			
		||||
				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				for _, choice := range streamResponse.Choices {
 | 
			
		||||
					responseText += choice.Text
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			if strings.HasPrefix(data, "data: [DONE]") {
 | 
			
		||||
				data = data[:12]
 | 
			
		||||
			}
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: data})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < dataPrefixLength { // ignore blank line or wrong format
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if strings.HasPrefix(data[dataPrefixLength:], done) {
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case relaymode.ChatCompletions:
 | 
			
		||||
			var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				render.StringData(c, data) // if error happened, pass the data to client
 | 
			
		||||
				continue                   // just ignore the error
 | 
			
		||||
			}
 | 
			
		||||
			if len(streamResponse.Choices) == 0 {
 | 
			
		||||
				// but for empty choice, we should not pass it to client, this is for azure
 | 
			
		||||
				continue // just ignore empty choice
 | 
			
		||||
			}
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
			for _, choice := range streamResponse.Choices {
 | 
			
		||||
				responseText += conv.AsString(choice.Delta.Content)
 | 
			
		||||
			}
 | 
			
		||||
			if streamResponse.Usage != nil {
 | 
			
		||||
				usage = streamResponse.Usage
 | 
			
		||||
			}
 | 
			
		||||
		case relaymode.Completions:
 | 
			
		||||
			render.StringData(c, data)
 | 
			
		||||
			var streamResponse CompletionsStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			for _, choice := range streamResponse.Choices {
 | 
			
		||||
				responseText += choice.Text
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, responseText, usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,10 @@ package palm
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
@@ -11,8 +15,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 | 
			
		||||
@@ -77,58 +79,51 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
 | 
			
		||||
	createdTime := helper.GetTimestamp()
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error reading stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error closing stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		var palmResponse ChatResponse
 | 
			
		||||
		err = json.Unmarshal(responseBody, &palmResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
 | 
			
		||||
		fullTextResponse.Id = responseId
 | 
			
		||||
		fullTextResponse.Created = createdTime
 | 
			
		||||
		if len(palmResponse.Candidates) > 0 {
 | 
			
		||||
			responseText = palmResponse.Candidates[0].Content
 | 
			
		||||
		}
 | 
			
		||||
		jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
			stopChan <- true
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		dataChan <- string(jsonResponse)
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + data})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("error reading stream response: " + err.Error())
 | 
			
		||||
		err := resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var palmResponse ChatResponse
 | 
			
		||||
	err = json.Unmarshal(responseBody, &palmResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
 | 
			
		||||
	fullTextResponse.Id = responseId
 | 
			
		||||
	fullTextResponse.Created = createdTime
 | 
			
		||||
	if len(palmResponse.Candidates) > 0 {
 | 
			
		||||
		responseText = palmResponse.Candidates[0].Content
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = render.ObjectData(c, string(jsonResponse))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,13 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/conv"
 | 
			
		||||
@@ -17,11 +24,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 | 
			
		||||
@@ -87,64 +89,46 @@ func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCom
 | 
			
		||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 | 
			
		||||
	var responseText string
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	scanner.Split(bufio.ScanLines)
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var TencentResponse ChatResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &TencentResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseTencent2OpenAI(&TencentResponse)
 | 
			
		||||
			if len(response.Choices) != 0 {
 | 
			
		||||
				responseText += conv.AsString(response.Choices[0].Delta.Content)
 | 
			
		||||
			}
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		if len(data) < 5 || !strings.HasPrefix(data, "data:") {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
		data = strings.TrimPrefix(data, "data:")
 | 
			
		||||
 | 
			
		||||
		var tencentResponse ChatResponse
 | 
			
		||||
		err := json.Unmarshal([]byte(data), &tencentResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response := streamResponseTencent2OpenAI(&tencentResponse)
 | 
			
		||||
		if len(response.Choices) != 0 {
 | 
			
		||||
			responseText += conv.AsString(response.Choices[0].Delta.Content)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,13 @@ package zhipu
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/render"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/golang-jwt/jwt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
@@ -11,11 +18,6 @@ import (
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/constant"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://open.bigmodel.cn/doc/api#chatglm_std
 | 
			
		||||
@@ -155,66 +157,55 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	metaChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			lines := strings.Split(data, "\n")
 | 
			
		||||
			for i, line := range lines {
 | 
			
		||||
				if len(line) < 5 {
 | 
			
		||||
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
		lines := strings.Split(data, "\n")
 | 
			
		||||
		for i, line := range lines {
 | 
			
		||||
			if len(line) < 5 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(line, "data:") {
 | 
			
		||||
				dataSegment := line[5:]
 | 
			
		||||
				if i != len(lines)-1 {
 | 
			
		||||
					dataSegment += "\n"
 | 
			
		||||
				}
 | 
			
		||||
				response := streamResponseZhipu2OpenAI(dataSegment)
 | 
			
		||||
				err := render.ObjectData(c, response)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			} else if strings.HasPrefix(line, "meta:") {
 | 
			
		||||
				metaSegment := line[5:]
 | 
			
		||||
				var zhipuResponse StreamMetaResponse
 | 
			
		||||
				err := json.Unmarshal([]byte(metaSegment), &zhipuResponse)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				if line[:5] == "data:" {
 | 
			
		||||
					dataChan <- line[5:]
 | 
			
		||||
					if i != len(lines)-1 {
 | 
			
		||||
						dataChan <- "\n"
 | 
			
		||||
					}
 | 
			
		||||
				} else if line[:5] == "meta:" {
 | 
			
		||||
					metaChan <- line[5:]
 | 
			
		||||
				response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 | 
			
		||||
				err = render.ObjectData(c, response)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				usage = zhipuUsage
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	common.SetEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			response := streamResponseZhipu2OpenAI(data)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case data := <-metaChan:
 | 
			
		||||
			var zhipuResponse StreamMetaResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &zhipuResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			usage = zhipuUsage
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := scanner.Err(); err != nil {
 | 
			
		||||
		logger.SysError("error reading stream: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	render.Done(c)
 | 
			
		||||
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user