|  |  |  | @@ -15,6 +15,12 @@ import ( | 
		
	
		
			
				|  |  |  |  | 	"github.com/gin-gonic/gin" | 
		
	
		
			
				|  |  |  |  | ) | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | const ( | 
		
	
		
			
				|  |  |  |  | 	APITypeOpenAI = iota | 
		
	
		
			
				|  |  |  |  | 	APITypeClaude | 
		
	
		
			
				|  |  |  |  | 	APITypePaLM | 
		
	
		
			
				|  |  |  |  | ) | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | 
		
	
		
			
				|  |  |  |  | 	channelType := c.GetInt("channel") | 
		
	
		
			
				|  |  |  |  | 	tokenId := c.GetInt("token_id") | 
		
	
	
		
			
				
					
					|  |  |  | @@ -71,33 +77,42 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | 
		
	
		
			
				|  |  |  |  | 			isModelMapped = true | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	apiType := APITypeOpenAI | 
		
	
		
			
				|  |  |  |  | 	if strings.HasPrefix(textRequest.Model, "claude") { | 
		
	
		
			
				|  |  |  |  | 		apiType = APITypeClaude | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	baseURL := common.ChannelBaseURLs[channelType] | 
		
	
		
			
				|  |  |  |  | 	requestURL := c.Request.URL.String() | 
		
	
		
			
				|  |  |  |  | 	if c.GetString("base_url") != "" { | 
		
	
		
			
				|  |  |  |  | 		baseURL = c.GetString("base_url") | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) | 
		
	
		
			
				|  |  |  |  | 	if channelType == common.ChannelTypeAzure { | 
		
	
		
			
				|  |  |  |  | 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | 
		
	
		
			
				|  |  |  |  | 		query := c.Request.URL.Query() | 
		
	
		
			
				|  |  |  |  | 		apiVersion := query.Get("api-version") | 
		
	
		
			
				|  |  |  |  | 		if apiVersion == "" { | 
		
	
		
			
				|  |  |  |  | 			apiVersion = c.GetString("api_version") | 
		
	
		
			
				|  |  |  |  | 	switch apiType { | 
		
	
		
			
				|  |  |  |  | 	case APITypeOpenAI: | 
		
	
		
			
				|  |  |  |  | 		if channelType == common.ChannelTypeAzure { | 
		
	
		
			
				|  |  |  |  | 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | 
		
	
		
			
				|  |  |  |  | 			query := c.Request.URL.Query() | 
		
	
		
			
				|  |  |  |  | 			apiVersion := query.Get("api-version") | 
		
	
		
			
				|  |  |  |  | 			if apiVersion == "" { | 
		
	
		
			
				|  |  |  |  | 				apiVersion = c.GetString("api_version") | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			requestURL := strings.Split(requestURL, "?")[0] | 
		
	
		
			
				|  |  |  |  | 			requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | 
		
	
		
			
				|  |  |  |  | 			baseURL = c.GetString("base_url") | 
		
	
		
			
				|  |  |  |  | 			task := strings.TrimPrefix(requestURL, "/v1/") | 
		
	
		
			
				|  |  |  |  | 			model_ := textRequest.Model | 
		
	
		
			
				|  |  |  |  | 			model_ = strings.Replace(model_, ".", "", -1) | 
		
	
		
			
				|  |  |  |  | 			// https://github.com/songquanpeng/one-api/issues/67 | 
		
	
		
			
				|  |  |  |  | 			model_ = strings.TrimSuffix(model_, "-0301") | 
		
	
		
			
				|  |  |  |  | 			model_ = strings.TrimSuffix(model_, "-0314") | 
		
	
		
			
				|  |  |  |  | 			model_ = strings.TrimSuffix(model_, "-0613") | 
		
	
		
			
				|  |  |  |  | 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 	case APITypeClaude: | 
		
	
		
			
				|  |  |  |  | 		fullRequestURL = "https://api.anthropic.com/v1/complete" | 
		
	
		
			
				|  |  |  |  | 		if baseURL != "" { | 
		
	
		
			
				|  |  |  |  | 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		requestURL := strings.Split(requestURL, "?")[0] | 
		
	
		
			
				|  |  |  |  | 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) | 
		
	
		
			
				|  |  |  |  | 		baseURL = c.GetString("base_url") | 
		
	
		
			
				|  |  |  |  | 		task := strings.TrimPrefix(requestURL, "/v1/") | 
		
	
		
			
				|  |  |  |  | 		model_ := textRequest.Model | 
		
	
		
			
				|  |  |  |  | 		model_ = strings.Replace(model_, ".", "", -1) | 
		
	
		
			
				|  |  |  |  | 		// https://github.com/songquanpeng/one-api/issues/67 | 
		
	
		
			
				|  |  |  |  | 		model_ = strings.TrimSuffix(model_, "-0301") | 
		
	
		
			
				|  |  |  |  | 		model_ = strings.TrimSuffix(model_, "-0314") | 
		
	
		
			
				|  |  |  |  | 		model_ = strings.TrimSuffix(model_, "-0613") | 
		
	
		
			
				|  |  |  |  | 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) | 
		
	
		
			
				|  |  |  |  | 	} else if channelType == common.ChannelTypePaLM { | 
		
	
		
			
				|  |  |  |  | 		err := relayPaLM(textRequest, c) | 
		
	
		
			
				|  |  |  |  | 		return err | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	var promptTokens int | 
		
	
		
			
				|  |  |  |  | 	var completionTokens int | 
		
	
	
		
			
				
					
					|  |  |  | @@ -142,16 +157,58 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | 
		
	
		
			
				|  |  |  |  | 	} else { | 
		
	
		
			
				|  |  |  |  | 		requestBody = c.Request.Body | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	switch apiType { | 
		
	
		
			
				|  |  |  |  | 	case APITypeClaude: | 
		
	
		
			
				|  |  |  |  | 		claudeRequest := ClaudeRequest{ | 
		
	
		
			
				|  |  |  |  | 			Model:             textRequest.Model, | 
		
	
		
			
				|  |  |  |  | 			Prompt:            "", | 
		
	
		
			
				|  |  |  |  | 			MaxTokensToSample: textRequest.MaxTokens, | 
		
	
		
			
				|  |  |  |  | 			StopSequences:     nil, | 
		
	
		
			
				|  |  |  |  | 			Temperature:       textRequest.Temperature, | 
		
	
		
			
				|  |  |  |  | 			TopP:              textRequest.TopP, | 
		
	
		
			
				|  |  |  |  | 			Stream:            textRequest.Stream, | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		if claudeRequest.MaxTokensToSample == 0 { | 
		
	
		
			
				|  |  |  |  | 			claudeRequest.MaxTokensToSample = 1000000 | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		prompt := "" | 
		
	
		
			
				|  |  |  |  | 		for _, message := range textRequest.Messages { | 
		
	
		
			
				|  |  |  |  | 			if message.Role == "user" { | 
		
	
		
			
				|  |  |  |  | 				prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) | 
		
	
		
			
				|  |  |  |  | 			} else if message.Role == "assistant" { | 
		
	
		
			
				|  |  |  |  | 				prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) | 
		
	
		
			
				|  |  |  |  | 			} else { | 
		
	
		
			
				|  |  |  |  | 				// ignore other roles | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			prompt += "\n\nAssistant:" | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		claudeRequest.Prompt = prompt | 
		
	
		
			
				|  |  |  |  | 		jsonStr, err := json.Marshal(claudeRequest) | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		requestBody = bytes.NewBuffer(jsonStr) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) | 
		
	
		
			
				|  |  |  |  | 	if err != nil { | 
		
	
		
			
				|  |  |  |  | 		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	if channelType == common.ChannelTypeAzure { | 
		
	
		
			
				|  |  |  |  | 		key := c.Request.Header.Get("Authorization") | 
		
	
		
			
				|  |  |  |  | 		key = strings.TrimPrefix(key, "Bearer ") | 
		
	
		
			
				|  |  |  |  | 		req.Header.Set("api-key", key) | 
		
	
		
			
				|  |  |  |  | 	} else { | 
		
	
		
			
				|  |  |  |  | 		req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | 
		
	
		
			
				|  |  |  |  | 	apiKey := c.Request.Header.Get("Authorization") | 
		
	
		
			
				|  |  |  |  | 	apiKey = strings.TrimPrefix(apiKey, "Bearer ") | 
		
	
		
			
				|  |  |  |  | 	switch apiType { | 
		
	
		
			
				|  |  |  |  | 	case APITypeOpenAI: | 
		
	
		
			
				|  |  |  |  | 		if channelType == common.ChannelTypeAzure { | 
		
	
		
			
				|  |  |  |  | 			req.Header.Set("api-key", apiKey) | 
		
	
		
			
				|  |  |  |  | 		} else { | 
		
	
		
			
				|  |  |  |  | 			req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 	case APITypeClaude: | 
		
	
		
			
				|  |  |  |  | 		req.Header.Set("x-api-key", apiKey) | 
		
	
		
			
				|  |  |  |  | 		anthropicVersion := c.Request.Header.Get("anthropic-version") | 
		
	
		
			
				|  |  |  |  | 		if anthropicVersion == "" { | 
		
	
		
			
				|  |  |  |  | 			anthropicVersion = "2023-06-01" | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		req.Header.Set("anthropic-version", anthropicVersion) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) | 
		
	
		
			
				|  |  |  |  | 	req.Header.Set("Accept", c.Request.Header.Get("Accept")) | 
		
	
	
		
			
				
					
					|  |  |  | @@ -219,87 +276,198 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 	}() | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 	if isStream { | 
		
	
		
			
				|  |  |  |  | 		scanner := bufio.NewScanner(resp.Body) | 
		
	
		
			
				|  |  |  |  | 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 
		
	
		
			
				|  |  |  |  | 			if atEOF && len(data) == 0 { | 
		
	
		
			
				|  |  |  |  | 				return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 			if i := strings.Index(string(data), "\n"); i >= 0 { | 
		
	
		
			
				|  |  |  |  | 				return i + 1, data[0:i], nil | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 			if atEOF { | 
		
	
		
			
				|  |  |  |  | 				return len(data), data, nil | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  |  | 
		
	
		
			
				|  |  |  |  | 			return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 		}) | 
		
	
		
			
				|  |  |  |  | 		dataChan := make(chan string) | 
		
	
		
			
				|  |  |  |  | 		stopChan := make(chan bool) | 
		
	
		
			
				|  |  |  |  | 		go func() { | 
		
	
		
			
				|  |  |  |  | 			for scanner.Scan() { | 
		
	
		
			
				|  |  |  |  | 				data := scanner.Text() | 
		
	
		
			
				|  |  |  |  | 				if len(data) < 6 { // ignore blank line or wrong format | 
		
	
		
			
				|  |  |  |  | 					continue | 
		
	
		
			
				|  |  |  |  | 	switch apiType { | 
		
	
		
			
				|  |  |  |  | 	case APITypeOpenAI: | 
		
	
		
			
				|  |  |  |  | 		if isStream { | 
		
	
		
			
				|  |  |  |  | 			scanner := bufio.NewScanner(resp.Body) | 
		
	
		
			
				|  |  |  |  | 			scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 
		
	
		
			
				|  |  |  |  | 				if atEOF && len(data) == 0 { | 
		
	
		
			
				|  |  |  |  | 					return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				dataChan <- data | 
		
	
		
			
				|  |  |  |  | 				data = data[6:] | 
		
	
		
			
				|  |  |  |  | 				if !strings.HasPrefix(data, "[DONE]") { | 
		
	
		
			
				|  |  |  |  | 					switch relayMode { | 
		
	
		
			
				|  |  |  |  | 					case RelayModeChatCompletions: | 
		
	
		
			
				|  |  |  |  | 						var streamResponse ChatCompletionsStreamResponse | 
		
	
		
			
				|  |  |  |  | 						err = json.Unmarshal([]byte(data), &streamResponse) | 
		
	
		
			
				|  |  |  |  | 						if err != nil { | 
		
	
		
			
				|  |  |  |  | 							common.SysError("error unmarshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 							return | 
		
	
		
			
				|  |  |  |  | 						} | 
		
	
		
			
				|  |  |  |  | 						for _, choice := range streamResponse.Choices { | 
		
	
		
			
				|  |  |  |  | 							streamResponseText += choice.Delta.Content | 
		
	
		
			
				|  |  |  |  | 						} | 
		
	
		
			
				|  |  |  |  | 					case RelayModeCompletions: | 
		
	
		
			
				|  |  |  |  | 						var streamResponse CompletionsStreamResponse | 
		
	
		
			
				|  |  |  |  | 						err = json.Unmarshal([]byte(data), &streamResponse) | 
		
	
		
			
				|  |  |  |  | 						if err != nil { | 
		
	
		
			
				|  |  |  |  | 							common.SysError("error unmarshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 							return | 
		
	
		
			
				|  |  |  |  | 						} | 
		
	
		
			
				|  |  |  |  | 						for _, choice := range streamResponse.Choices { | 
		
	
		
			
				|  |  |  |  | 							streamResponseText += choice.Text | 
		
	
		
			
				|  |  |  |  | 				if i := strings.Index(string(data), "\n"); i >= 0 { | 
		
	
		
			
				|  |  |  |  | 					return i + 1, data[0:i], nil | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				if atEOF { | 
		
	
		
			
				|  |  |  |  | 					return len(data), data, nil | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 			}) | 
		
	
		
			
				|  |  |  |  | 			dataChan := make(chan string) | 
		
	
		
			
				|  |  |  |  | 			stopChan := make(chan bool) | 
		
	
		
			
				|  |  |  |  | 			go func() { | 
		
	
		
			
				|  |  |  |  | 				for scanner.Scan() { | 
		
	
		
			
				|  |  |  |  | 					data := scanner.Text() | 
		
	
		
			
				|  |  |  |  | 					if len(data) < 6 { // ignore blank line or wrong format | 
		
	
		
			
				|  |  |  |  | 						continue | 
		
	
		
			
				|  |  |  |  | 					} | 
		
	
		
			
				|  |  |  |  | 					dataChan <- data | 
		
	
		
			
				|  |  |  |  | 					data = data[6:] | 
		
	
		
			
				|  |  |  |  | 					if !strings.HasPrefix(data, "[DONE]") { | 
		
	
		
			
				|  |  |  |  | 						switch relayMode { | 
		
	
		
			
				|  |  |  |  | 						case RelayModeChatCompletions: | 
		
	
		
			
				|  |  |  |  | 							var streamResponse ChatCompletionsStreamResponse | 
		
	
		
			
				|  |  |  |  | 							err = json.Unmarshal([]byte(data), &streamResponse) | 
		
	
		
			
				|  |  |  |  | 							if err != nil { | 
		
	
		
			
				|  |  |  |  | 								common.SysError("error unmarshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 								return | 
		
	
		
			
				|  |  |  |  | 							} | 
		
	
		
			
				|  |  |  |  | 							for _, choice := range streamResponse.Choices { | 
		
	
		
			
				|  |  |  |  | 								streamResponseText += choice.Delta.Content | 
		
	
		
			
				|  |  |  |  | 							} | 
		
	
		
			
				|  |  |  |  | 						case RelayModeCompletions: | 
		
	
		
			
				|  |  |  |  | 							var streamResponse CompletionsStreamResponse | 
		
	
		
			
				|  |  |  |  | 							err = json.Unmarshal([]byte(data), &streamResponse) | 
		
	
		
			
				|  |  |  |  | 							if err != nil { | 
		
	
		
			
				|  |  |  |  | 								common.SysError("error unmarshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 								return | 
		
	
		
			
				|  |  |  |  | 							} | 
		
	
		
			
				|  |  |  |  | 							for _, choice := range streamResponse.Choices { | 
		
	
		
			
				|  |  |  |  | 								streamResponseText += choice.Text | 
		
	
		
			
				|  |  |  |  | 							} | 
		
	
		
			
				|  |  |  |  | 						} | 
		
	
		
			
				|  |  |  |  | 					} | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			stopChan <- true | 
		
	
		
			
				|  |  |  |  | 		}() | 
		
	
		
			
				|  |  |  |  | 		c.Writer.Header().Set("Content-Type", "text/event-stream") | 
		
	
		
			
				|  |  |  |  | 		c.Writer.Header().Set("Cache-Control", "no-cache") | 
		
	
		
			
				|  |  |  |  | 		c.Writer.Header().Set("Connection", "keep-alive") | 
		
	
		
			
				|  |  |  |  | 		c.Writer.Header().Set("Transfer-Encoding", "chunked") | 
		
	
		
			
				|  |  |  |  | 		c.Writer.Header().Set("X-Accel-Buffering", "no") | 
		
	
		
			
				|  |  |  |  | 		c.Stream(func(w io.Writer) bool { | 
		
	
		
			
				|  |  |  |  | 			select { | 
		
	
		
			
				|  |  |  |  | 			case data := <-dataChan: | 
		
	
		
			
				|  |  |  |  | 				if strings.HasPrefix(data, "data: [DONE]") { | 
		
	
		
			
				|  |  |  |  | 					data = data[:12] | 
		
	
		
			
				|  |  |  |  | 				stopChan <- true | 
		
	
		
			
				|  |  |  |  | 			}() | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Content-Type", "text/event-stream") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Cache-Control", "no-cache") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Connection", "keep-alive") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Transfer-Encoding", "chunked") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("X-Accel-Buffering", "no") | 
		
	
		
			
				|  |  |  |  | 			c.Stream(func(w io.Writer) bool { | 
		
	
		
			
				|  |  |  |  | 				select { | 
		
	
		
			
				|  |  |  |  | 				case data := <-dataChan: | 
		
	
		
			
				|  |  |  |  | 					if strings.HasPrefix(data, "data: [DONE]") { | 
		
	
		
			
				|  |  |  |  | 						data = data[:12] | 
		
	
		
			
				|  |  |  |  | 					} | 
		
	
		
			
				|  |  |  |  | 					// some implementations may add \r at the end of data | 
		
	
		
			
				|  |  |  |  | 					data = strings.TrimSuffix(data, "\r") | 
		
	
		
			
				|  |  |  |  | 					c.Render(-1, common.CustomEvent{Data: data}) | 
		
	
		
			
				|  |  |  |  | 					return true | 
		
	
		
			
				|  |  |  |  | 				case <-stopChan: | 
		
	
		
			
				|  |  |  |  | 					return false | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				// some implementations may add \r at the end of data | 
		
	
		
			
				|  |  |  |  | 				data = strings.TrimSuffix(data, "\r") | 
		
	
		
			
				|  |  |  |  | 				c.Render(-1, common.CustomEvent{Data: data}) | 
		
	
		
			
				|  |  |  |  | 				return true | 
		
	
		
			
				|  |  |  |  | 			case <-stopChan: | 
		
	
		
			
				|  |  |  |  | 				return false | 
		
	
		
			
				|  |  |  |  | 			}) | 
		
	
		
			
				|  |  |  |  | 			err = resp.Body.Close() | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 		}) | 
		
	
		
			
				|  |  |  |  | 		err = resp.Body.Close() | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			return nil | 
		
	
		
			
				|  |  |  |  | 		} else { | 
		
	
		
			
				|  |  |  |  | 			if consumeQuota { | 
		
	
		
			
				|  |  |  |  | 				responseBody, err := io.ReadAll(resp.Body) | 
		
	
		
			
				|  |  |  |  | 				if err != nil { | 
		
	
		
			
				|  |  |  |  | 					return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				err = resp.Body.Close() | 
		
	
		
			
				|  |  |  |  | 				if err != nil { | 
		
	
		
			
				|  |  |  |  | 					return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				err = json.Unmarshal(responseBody, &textResponse) | 
		
	
		
			
				|  |  |  |  | 				if err != nil { | 
		
	
		
			
				|  |  |  |  | 					return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				if textResponse.Error.Type != "" { | 
		
	
		
			
				|  |  |  |  | 					return &OpenAIErrorWithStatusCode{ | 
		
	
		
			
				|  |  |  |  | 						OpenAIError: textResponse.Error, | 
		
	
		
			
				|  |  |  |  | 						StatusCode:  resp.StatusCode, | 
		
	
		
			
				|  |  |  |  | 					} | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				// Reset response body | 
		
	
		
			
				|  |  |  |  | 				resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			// We shouldn't set the header before we parse the response body, because the parse part may fail. | 
		
	
		
			
				|  |  |  |  | 			// And then we will have to send an error response, but in this case, the header has already been set. | 
		
	
		
			
				|  |  |  |  | 			// So the client will be confused by the response. | 
		
	
		
			
				|  |  |  |  | 			// For example, Postman will report error, and we cannot check the response at all. | 
		
	
		
			
				|  |  |  |  | 			for k, v := range resp.Header { | 
		
	
		
			
				|  |  |  |  | 				c.Writer.Header().Set(k, v[0]) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			c.Writer.WriteHeader(resp.StatusCode) | 
		
	
		
			
				|  |  |  |  | 			_, err = io.Copy(c.Writer, resp.Body) | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			err = resp.Body.Close() | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			return nil | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		return nil | 
		
	
		
			
				|  |  |  |  | 	} else { | 
		
	
		
			
				|  |  |  |  | 		if consumeQuota { | 
		
	
		
			
				|  |  |  |  | 	case APITypeClaude: | 
		
	
		
			
				|  |  |  |  | 		if isStream { | 
		
	
		
			
				|  |  |  |  | 			responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) | 
		
	
		
			
				|  |  |  |  | 			createdTime := common.GetTimestamp() | 
		
	
		
			
				|  |  |  |  | 			scanner := bufio.NewScanner(resp.Body) | 
		
	
		
			
				|  |  |  |  | 			scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | 
		
	
		
			
				|  |  |  |  | 				if atEOF && len(data) == 0 { | 
		
	
		
			
				|  |  |  |  | 					return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { | 
		
	
		
			
				|  |  |  |  | 					return i + 4, data[0:i], nil | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				if atEOF { | 
		
	
		
			
				|  |  |  |  | 					return len(data), data, nil | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				return 0, nil, nil | 
		
	
		
			
				|  |  |  |  | 			}) | 
		
	
		
			
				|  |  |  |  | 			dataChan := make(chan string) | 
		
	
		
			
				|  |  |  |  | 			stopChan := make(chan bool) | 
		
	
		
			
				|  |  |  |  | 			go func() { | 
		
	
		
			
				|  |  |  |  | 				for scanner.Scan() { | 
		
	
		
			
				|  |  |  |  | 					data := scanner.Text() | 
		
	
		
			
				|  |  |  |  | 					if !strings.HasPrefix(data, "event: completion") { | 
		
	
		
			
				|  |  |  |  | 						continue | 
		
	
		
			
				|  |  |  |  | 					} | 
		
	
		
			
				|  |  |  |  | 					data = strings.TrimPrefix(data, "event: completion\r\ndata: ") | 
		
	
		
			
				|  |  |  |  | 					dataChan <- data | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 				stopChan <- true | 
		
	
		
			
				|  |  |  |  | 			}() | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Content-Type", "text/event-stream") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Cache-Control", "no-cache") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Connection", "keep-alive") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Transfer-Encoding", "chunked") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("X-Accel-Buffering", "no") | 
		
	
		
			
				|  |  |  |  | 			c.Stream(func(w io.Writer) bool { | 
		
	
		
			
				|  |  |  |  | 				select { | 
		
	
		
			
				|  |  |  |  | 				case data := <-dataChan: | 
		
	
		
			
				|  |  |  |  | 					// some implementations may add \r at the end of data | 
		
	
		
			
				|  |  |  |  | 					data = strings.TrimSuffix(data, "\r") | 
		
	
		
			
				|  |  |  |  | 					var claudeResponse ClaudeResponse | 
		
	
		
			
				|  |  |  |  | 					err = json.Unmarshal([]byte(data), &claudeResponse) | 
		
	
		
			
				|  |  |  |  | 					if err != nil { | 
		
	
		
			
				|  |  |  |  | 						common.SysError("error unmarshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 						return true | 
		
	
		
			
				|  |  |  |  | 					} | 
		
	
		
			
				|  |  |  |  | 					streamResponseText += claudeResponse.Completion | 
		
	
		
			
				|  |  |  |  | 					var choice ChatCompletionsStreamResponseChoice | 
		
	
		
			
				|  |  |  |  | 					choice.Delta.Content = claudeResponse.Completion | 
		
	
		
			
				|  |  |  |  | 					choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason) | 
		
	
		
			
				|  |  |  |  | 					var response ChatCompletionsStreamResponse | 
		
	
		
			
				|  |  |  |  | 					response.Id = responseId | 
		
	
		
			
				|  |  |  |  | 					response.Created = createdTime | 
		
	
		
			
				|  |  |  |  | 					response.Object = "chat.completion.chunk" | 
		
	
		
			
				|  |  |  |  | 					response.Model = textRequest.Model | 
		
	
		
			
				|  |  |  |  | 					response.Choices = []ChatCompletionsStreamResponseChoice{choice} | 
		
	
		
			
				|  |  |  |  | 					jsonStr, err := json.Marshal(response) | 
		
	
		
			
				|  |  |  |  | 					if err != nil { | 
		
	
		
			
				|  |  |  |  | 						common.SysError("error marshalling stream response: " + err.Error()) | 
		
	
		
			
				|  |  |  |  | 						return true | 
		
	
		
			
				|  |  |  |  | 					} | 
		
	
		
			
				|  |  |  |  | 					c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | 
		
	
		
			
				|  |  |  |  | 					return true | 
		
	
		
			
				|  |  |  |  | 				case <-stopChan: | 
		
	
		
			
				|  |  |  |  | 					c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | 
		
	
		
			
				|  |  |  |  | 					return false | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 			}) | 
		
	
		
			
				|  |  |  |  | 			err = resp.Body.Close() | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			return nil | 
		
	
		
			
				|  |  |  |  | 		} else { | 
		
	
		
			
				|  |  |  |  | 			responseBody, err := io.ReadAll(resp.Body) | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) | 
		
	
	
		
			
				
					
					|  |  |  | @@ -308,35 +476,54 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			err = json.Unmarshal(responseBody, &textResponse) | 
		
	
		
			
				|  |  |  |  | 			var claudeResponse ClaudeResponse | 
		
	
		
			
				|  |  |  |  | 			err = json.Unmarshal(responseBody, &claudeResponse) | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			if textResponse.Error.Type != "" { | 
		
	
		
			
				|  |  |  |  | 			if claudeResponse.Error.Type != "" { | 
		
	
		
			
				|  |  |  |  | 				return &OpenAIErrorWithStatusCode{ | 
		
	
		
			
				|  |  |  |  | 					OpenAIError: textResponse.Error, | 
		
	
		
			
				|  |  |  |  | 					StatusCode:  resp.StatusCode, | 
		
	
		
			
				|  |  |  |  | 					OpenAIError: OpenAIError{ | 
		
	
		
			
				|  |  |  |  | 						Message: claudeResponse.Error.Message, | 
		
	
		
			
				|  |  |  |  | 						Type:    claudeResponse.Error.Type, | 
		
	
		
			
				|  |  |  |  | 						Param:   "", | 
		
	
		
			
				|  |  |  |  | 						Code:    claudeResponse.Error.Type, | 
		
	
		
			
				|  |  |  |  | 					}, | 
		
	
		
			
				|  |  |  |  | 					StatusCode: resp.StatusCode, | 
		
	
		
			
				|  |  |  |  | 				} | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			// Reset response body | 
		
	
		
			
				|  |  |  |  | 			resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) | 
		
	
		
			
				|  |  |  |  | 			choice := OpenAITextResponseChoice{ | 
		
	
		
			
				|  |  |  |  | 				Index: 0, | 
		
	
		
			
				|  |  |  |  | 				Message: Message{ | 
		
	
		
			
				|  |  |  |  | 					Role:    "assistant", | 
		
	
		
			
				|  |  |  |  | 					Content: strings.TrimPrefix(claudeResponse.Completion, " "), | 
		
	
		
			
				|  |  |  |  | 					Name:    nil, | 
		
	
		
			
				|  |  |  |  | 				}, | 
		
	
		
			
				|  |  |  |  | 				FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			completionTokens := countTokenText(claudeResponse.Completion, textRequest.Model) | 
		
	
		
			
				|  |  |  |  | 			fullTextResponse := OpenAITextResponse{ | 
		
	
		
			
				|  |  |  |  | 				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()), | 
		
	
		
			
				|  |  |  |  | 				Object:  "chat.completion", | 
		
	
		
			
				|  |  |  |  | 				Created: common.GetTimestamp(), | 
		
	
		
			
				|  |  |  |  | 				Choices: []OpenAITextResponseChoice{choice}, | 
		
	
		
			
				|  |  |  |  | 				Usage: Usage{ | 
		
	
		
			
				|  |  |  |  | 					PromptTokens:     promptTokens, | 
		
	
		
			
				|  |  |  |  | 					CompletionTokens: completionTokens, | 
		
	
		
			
				|  |  |  |  | 					TotalTokens:      promptTokens + promptTokens, | 
		
	
		
			
				|  |  |  |  | 				}, | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			textResponse.Usage = fullTextResponse.Usage | 
		
	
		
			
				|  |  |  |  | 			jsonResponse, err := json.Marshal(fullTextResponse) | 
		
	
		
			
				|  |  |  |  | 			if err != nil { | 
		
	
		
			
				|  |  |  |  | 				return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 			} | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set("Content-Type", "application/json") | 
		
	
		
			
				|  |  |  |  | 			c.Writer.WriteHeader(resp.StatusCode) | 
		
	
		
			
				|  |  |  |  | 			_, err = c.Writer.Write(jsonResponse) | 
		
	
		
			
				|  |  |  |  | 			return nil | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		// We shouldn't set the header before we parse the response body, because the parse part may fail. | 
		
	
		
			
				|  |  |  |  | 		// And then we will have to send an error response, but in this case, the header has already been set. | 
		
	
		
			
				|  |  |  |  | 		// So the client will be confused by the response. | 
		
	
		
			
				|  |  |  |  | 		// For example, Postman will report error, and we cannot check the response at all. | 
		
	
		
			
				|  |  |  |  | 		for k, v := range resp.Header { | 
		
	
		
			
				|  |  |  |  | 			c.Writer.Header().Set(k, v[0]) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		c.Writer.WriteHeader(resp.StatusCode) | 
		
	
		
			
				|  |  |  |  | 		_, err = io.Copy(c.Writer, resp.Body) | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		err = resp.Body.Close() | 
		
	
		
			
				|  |  |  |  | 		if err != nil { | 
		
	
		
			
				|  |  |  |  | 			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 		} | 
		
	
		
			
				|  |  |  |  | 		return nil | 
		
	
		
			
				|  |  |  |  | 	default: | 
		
	
		
			
				|  |  |  |  | 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) | 
		
	
		
			
				|  |  |  |  | 	} | 
		
	
		
			
				|  |  |  |  | } | 
		
	
	
		
			
				
					
					| 
							
							
							
						 |  |  |   |