mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	refactor: refactor openai related code
This commit is contained in:
		
							
								
								
									
										133
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								controller/relay-openai.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,133 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
 | 
			
		||||
	responseText := ""
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
			data = data[6:]
 | 
			
		||||
			if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
				switch relayMode {
 | 
			
		||||
				case RelayModeChatCompletions:
 | 
			
		||||
					var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					for _, choice := range streamResponse.Choices {
 | 
			
		||||
						responseText += choice.Delta.Content
 | 
			
		||||
					}
 | 
			
		||||
				case RelayModeCompletions:
 | 
			
		||||
					var streamResponse CompletionsStreamResponse
 | 
			
		||||
					err := json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
					for _, choice := range streamResponse.Choices {
 | 
			
		||||
						responseText += choice.Text
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			if strings.HasPrefix(data, "data: [DONE]") {
 | 
			
		||||
				data = data[:12]
 | 
			
		||||
			}
 | 
			
		||||
			// some implementations may add \r at the end of data
 | 
			
		||||
			data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: data})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 | 
			
		||||
	}
 | 
			
		||||
	return nil, responseText
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}
 | 
			
		||||
		if textResponse.Error.Type != "" {
 | 
			
		||||
			return &OpenAIErrorWithStatusCode{
 | 
			
		||||
				OpenAIError: textResponse.Error,
 | 
			
		||||
				StatusCode:  resp.StatusCode,
 | 
			
		||||
			}, nil
 | 
			
		||||
		}
 | 
			
		||||
		// Reset response body
 | 
			
		||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	}
 | 
			
		||||
	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
	// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
	// So the client will be confused by the response.
 | 
			
		||||
	// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err := io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &textResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
@@ -256,119 +255,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
			scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
				if atEOF && len(data) == 0 {
 | 
			
		||||
					return 0, nil, nil
 | 
			
		||||
				}
 | 
			
		||||
				if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
					return i + 1, data[0:i], nil
 | 
			
		||||
				}
 | 
			
		||||
				if atEOF {
 | 
			
		||||
					return len(data), data, nil
 | 
			
		||||
				}
 | 
			
		||||
				return 0, nil, nil
 | 
			
		||||
			})
 | 
			
		||||
			dataChan := make(chan string)
 | 
			
		||||
			stopChan := make(chan bool)
 | 
			
		||||
			go func() {
 | 
			
		||||
				for scanner.Scan() {
 | 
			
		||||
					data := scanner.Text()
 | 
			
		||||
					if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					dataChan <- data
 | 
			
		||||
					data = data[6:]
 | 
			
		||||
					if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
						switch relayMode {
 | 
			
		||||
						case RelayModeChatCompletions:
 | 
			
		||||
							var streamResponse ChatCompletionsStreamResponse
 | 
			
		||||
							err = json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
							if err != nil {
 | 
			
		||||
								common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
								return
 | 
			
		||||
							}
 | 
			
		||||
							for _, choice := range streamResponse.Choices {
 | 
			
		||||
								streamResponseText += choice.Delta.Content
 | 
			
		||||
							}
 | 
			
		||||
						case RelayModeCompletions:
 | 
			
		||||
							var streamResponse CompletionsStreamResponse
 | 
			
		||||
							err = json.Unmarshal([]byte(data), &streamResponse)
 | 
			
		||||
							if err != nil {
 | 
			
		||||
								common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
								return
 | 
			
		||||
							}
 | 
			
		||||
							for _, choice := range streamResponse.Choices {
 | 
			
		||||
								streamResponseText += choice.Text
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				stopChan <- true
 | 
			
		||||
			}()
 | 
			
		||||
			c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
			c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
			c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
			c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
			c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
			c.Stream(func(w io.Writer) bool {
 | 
			
		||||
				select {
 | 
			
		||||
				case data := <-dataChan:
 | 
			
		||||
					if strings.HasPrefix(data, "data: [DONE]") {
 | 
			
		||||
						data = data[:12]
 | 
			
		||||
					}
 | 
			
		||||
					// some implementations may add \r at the end of data
 | 
			
		||||
					data = strings.TrimSuffix(data, "\r")
 | 
			
		||||
					c.Render(-1, common.CustomEvent{Data: data})
 | 
			
		||||
					return true
 | 
			
		||||
				case <-stopChan:
 | 
			
		||||
					return false
 | 
			
		||||
				}
 | 
			
		||||
			})
 | 
			
		||||
			err = resp.Body.Close()
 | 
			
		||||
			err, responseText := openaiStreamHandler(c, resp, relayMode)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			streamResponseText = responseText
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			if consumeQuota {
 | 
			
		||||
				responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
				}
 | 
			
		||||
				err = resp.Body.Close()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
				}
 | 
			
		||||
				err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
				}
 | 
			
		||||
				if textResponse.Error.Type != "" {
 | 
			
		||||
					return &OpenAIErrorWithStatusCode{
 | 
			
		||||
						OpenAIError: textResponse.Error,
 | 
			
		||||
						StatusCode:  resp.StatusCode,
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				// Reset response body
 | 
			
		||||
				resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
			}
 | 
			
		||||
			// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
			// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
			// So the client will be confused by the response.
 | 
			
		||||
			// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
			for k, v := range resp.Header {
 | 
			
		||||
				c.Writer.Header().Set(k, v[0])
 | 
			
		||||
			}
 | 
			
		||||
			c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
			_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
			err, usage := openaiHandler(c, resp, consumeQuota)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
			}
 | 
			
		||||
			err = resp.Body.Close()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			textResponse.Usage = *usage
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeClaude:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user