| 
							
							
							
						 |  |  | @@ -1,11 +1,13 @@ | 
		
	
		
			
				|  |  |  |  | package controller | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | import ( | 
		
	
		
			
				|  |  |  |  | 	"bufio" | 
		
	
		
			
				|  |  |  |  | 	"encoding/json" | 
		
	
		
			
				|  |  |  |  | 	"fmt" | 
		
	
		
			
				|  |  |  |  | 	"io" | 
		
	
		
			
				|  |  |  |  | 	"net/http" | 
		
	
		
			
				|  |  |  |  | 	"one-api/common" | 
		
	
		
			
				|  |  |  |  | 	"strings" | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	"github.com/gin-gonic/gin" | 
		
	
		
			
				|  |  |  |  | ) | 
		
	
	
		
			
				
					
					|  |  |  | @@ -180,50 +182,61 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCo | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { | 
		
	
		
			
				|  |  |  |  | 	responseText := "" | 
		
	
		
			
				|  |  |  |  | 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | 
		
	
		
			
				|  |  |  |  | 	createdTime := common.GetTimestamp() | 
		
	
		
			
				|  |  |  |  | 	dataChan := make(chan string) | 
		
	
		
			
				|  |  |  |  | 	stopChan := make(chan bool) | 
		
	
		
			
				|  |  |  |  | 	scanner := bufio.NewScanner(resp.Body) | 
		
	
		
			
				|  |  |  |  | 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 
		
	
		
			
				|  |  |  |  | 		if atEOF && len(data) == 0 { | 
		
	
		
			
				|  |  |  |  | 			return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		if i := strings.Index(string(data), "\n"); i >= 0 { | 
		
	
		
			
				|  |  |  |  | 			return i + 1, data[0:i], nil | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		if atEOF { | 
		
	
		
			
				|  |  |  |  | 			return len(data), data, nil | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 	}) | 
		
	
		
			
				|  |  |  |  | 	go func() { | 
		
	
		
			
				|  |  |  |  | 		responseBody, err := io.ReadAll(resp.Body) | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			common.SysError("error reading stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 			stopChan <- true | 
		
	
		
			
				|  |  |  |  | 			return | 
		
	
		
			
				|  |  |  |  | 		for scanner.Scan() { | 
		
	
		
			
				|  |  |  |  | 			data := scanner.Text() | 
		
	
		
			
				|  |  |  |  | 			data = strings.TrimSpace(data) | 
		
	
		
			
				|  |  |  |  | 			if !strings.HasPrefix(data, "\"text\": \"") { | 
		
	
		
			
				|  |  |  |  | 				continue | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			data = strings.TrimPrefix(data, "\"text\": \"") | 
		
	
		
			
				|  |  |  |  | 			data = strings.TrimSuffix(data, "\"") | 
		
	
		
			
				|  |  |  |  | 			dataChan <- data | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		err = resp.Body.Close() | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			common.SysError("error closing stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 			stopChan <- true | 
		
	
		
			
				|  |  |  |  | 			return | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		var geminiResponse GeminiChatResponse | 
		
	
		
			
				|  |  |  |  | 		err = json.Unmarshal(responseBody, &geminiResponse) | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			common.SysError("error unmarshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 			stopChan <- true | 
		
	
		
			
				|  |  |  |  | 			return | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) | 
		
	
		
			
				|  |  |  |  | 		fullTextResponse.Id = responseId | 
		
	
		
			
				|  |  |  |  | 		fullTextResponse.Created = createdTime | 
		
	
		
			
				|  |  |  |  | 		if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { | 
		
	
		
			
				|  |  |  |  | 			responseText += geminiResponse.Candidates[0].Content.Parts[0].Text | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		jsonResponse, err := json.Marshal(fullTextResponse) | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			common.SysError("error marshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 			stopChan <- true | 
		
	
		
			
				|  |  |  |  | 			return | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		dataChan <- string(jsonResponse) | 
		
	
		
			
				|  |  |  |  | 		stopChan <- true | 
		
	
		
			
				|  |  |  |  | 	}() | 
		
	
		
			
				|  |  |  |  | 	setEventStreamHeaders(c) | 
		
	
		
			
				|  |  |  |  | 	c.Stream(func(w io.Writer) bool { | 
		
	
		
			
				|  |  |  |  | 		select { | 
		
	
		
			
				|  |  |  |  | 		case data := <-dataChan: | 
		
	
		
			
				|  |  |  |  | 			c.Render(-1, common.CustomEvent{Data: "data: " + data}) | 
		
	
		
			
				|  |  |  |  | 			// this is used to prevent annoying \ related format bug | 
		
	
		
			
				|  |  |  |  | 			data = fmt.Sprintf("{\"content\": \"%s\"}", data) | 
		
	
		
			
				|  |  |  |  | 			type dummyStruct struct { | 
		
	
		
			
				|  |  |  |  | 				Content string `json:"content"` | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			var dummy dummyStruct | 
		
	
		
			
				|  |  |  |  | 			err := json.Unmarshal([]byte(data), &dummy) | 
		
	
		
			
				|  |  |  |  | 			responseText += dummy.Content | 
		
	
		
			
				|  |  |  |  | 			var choice ChatCompletionsStreamResponseChoice | 
		
	
		
			
				|  |  |  |  | 			choice.Delta.Content = dummy.Content | 
		
	
		
			
				|  |  |  |  | 			response := ChatCompletionsStreamResponse{ | 
		
	
		
			
				|  |  |  |  | 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | 
		
	
		
			
				|  |  |  |  | 				Object:  "chat.completion.chunk", | 
		
	
		
			
				|  |  |  |  | 				Created: common.GetTimestamp(), | 
		
	
		
			
				|  |  |  |  | 				Model:   "gemini-pro", | 
		
	
		
			
				|  |  |  |  | 				Choices: []ChatCompletionsStreamResponseChoice{choice}, | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			jsonResponse, err := json.Marshal(response) | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				common.SysError("error marshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 				return true | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | 
		
	
		
			
				|  |  |  |  | 			return true | 
		
	
		
			
				|  |  |  |  | 		case <-stopChan: | 
		
	
		
			
				|  |  |  |  | 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | 
		
	
	
		
			
				
					
					|  |  |  |   |