mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-13 03:43:44 +08:00
feat: support reverse proxy of Chanzhaoyu/chatgpt-web
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -27,6 +28,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
|
||||
} else if channel.Type == common.ChannelTypeChatGPTWeb {
|
||||
if channel.BaseURL != "" {
|
||||
requestURL = channel.BaseURL
|
||||
}
|
||||
requestURL += "/api/chat-process"
|
||||
} else {
|
||||
if channel.BaseURL != "" {
|
||||
requestURL = channel.BaseURL
|
||||
@@ -35,6 +41,41 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(request)
|
||||
|
||||
if channel.Type == common.ChannelTypeChatGPTWeb {
|
||||
// Get system message from Message json, Role == "system"
|
||||
var systemMessage Message
|
||||
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
systemMessage = message
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var prompt string
|
||||
|
||||
// Get all the Message, Roles from request.Messages, and format it into string by
|
||||
// ||> role: content
|
||||
for _, message := range request.Messages {
|
||||
// Exclude system message
|
||||
if message.Role == "system" {
|
||||
continue
|
||||
}
|
||||
prompt += "||> " + message.Role + ": " + message.Content + "\n"
|
||||
}
|
||||
|
||||
// Construct json data without adding escape character
|
||||
map1 := map[string]string{
|
||||
"prompt": prompt,
|
||||
"systemMessage": systemMessage.Content,
|
||||
"temperature": strconv.FormatFloat(request.Temperature, 'f', 2, 64),
|
||||
"top_p": strconv.FormatFloat(request.TopP, 'f', 2, 64),
|
||||
}
|
||||
|
||||
// Convert map to json string
|
||||
jsonData, err = json.Marshal(map1)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -104,52 +145,83 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
common.SysError("invalid stream response: " + data)
|
||||
continue
|
||||
}
|
||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||
// Remove event: event in the front or back
|
||||
data = strings.TrimPrefix(data, "event: event")
|
||||
data = strings.TrimSuffix(data, "event: event")
|
||||
// Remove everything, only keep `data: {...}` <--- this is the json
|
||||
// Find the start and end indices of `data: {...}` substring
|
||||
startIndex := strings.Index(data, "data:")
|
||||
endIndex := strings.LastIndex(data, "}")
|
||||
if channel.Type != common.ChannelTypeChatGPTWeb {
|
||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||
// Remove event: event in the front or back
|
||||
data = strings.TrimPrefix(data, "event: event")
|
||||
data = strings.TrimSuffix(data, "event: event")
|
||||
// Remove everything, only keep `data: {...}` <--- this is the json
|
||||
// Find the start and end indices of `data: {...}` substring
|
||||
startIndex := strings.Index(data, "data:")
|
||||
endIndex := strings.LastIndex(data, "}")
|
||||
|
||||
// If both indices are found and end index is greater than start index
|
||||
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
|
||||
// Extract the `data: {...}` substring
|
||||
data = data[startIndex : endIndex+1]
|
||||
}
|
||||
// If both indices are found and end index is greater than start index
|
||||
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
|
||||
// Extract the `data: {...}` substring
|
||||
data = data[startIndex : endIndex+1]
|
||||
}
|
||||
|
||||
// Trim whitespace and newlines from the modified data string
|
||||
data = strings.TrimSpace(data)
|
||||
}
|
||||
if !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
var streamResponse ChatCompletionsStreamResponse
|
||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
// Prinnt the body in string
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(resp.Body)
|
||||
common.SysError("error unmarshalling stream response: " + err.Error() + " " + buf.String())
|
||||
return err
|
||||
// Trim whitespace and newlines from the modified data string
|
||||
data = strings.TrimSpace(data)
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
streamResponseText += choice.Delta.Content
|
||||
if !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
var streamResponse ChatCompletionsStreamResponse
|
||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
// Prinnt the body in string
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(resp.Body)
|
||||
common.SysError("error unmarshalling stream response: " + err.Error() + " " + buf.String())
|
||||
return err
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
streamResponseText += choice.Delta.Content
|
||||
}
|
||||
} else {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
} else if channel.Type == common.ChannelTypeChatGPTWeb {
|
||||
// data may contain multiple json objects, so we need to split them
|
||||
// they are "{....}{....}{....}" or "{....}\n{....}\n{....}" or "{....}"
|
||||
|
||||
// remove all spaces and newlines outside of json objects
|
||||
jsonObjs := strings.Split(data, "\n") // Split the data into multiple JSON objects
|
||||
for _, jsonObj := range jsonObjs {
|
||||
if jsonObj == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var chatResponse ChatGptWebChatResponse
|
||||
err = json.Unmarshal([]byte(jsonObj), &chatResponse)
|
||||
if err != nil {
|
||||
// Print the body in string
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(resp.Body)
|
||||
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
|
||||
return err
|
||||
}
|
||||
|
||||
// if response role is assistant and contains delta, append the content to streamResponseText
|
||||
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
|
||||
for _, choice := range chatResponse.Detail.Choices {
|
||||
log.Print(choice.Delta.Content)
|
||||
streamResponseText += choice.Delta.Content
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check if streaming is complete and streamResponseText is populated
|
||||
if streamResponseText == "" || !done {
|
||||
if streamResponseText == "" || !done && channel.Type != common.ChannelTypeChatGPTWeb {
|
||||
return errors.New("Streaming not complete")
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -114,6 +115,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
model_ = strings.TrimSuffix(model_, "-0613")
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
||||
} else if channelType == common.ChannelTypeChatGPTWeb {
|
||||
// remove /v1/chat/completions from request url
|
||||
requestURL := strings.Split(requestURL, "/v1/chat/completions")[0]
|
||||
requestURL += "/api/chat-process"
|
||||
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
} else if channelType == common.ChannelTypePaLM {
|
||||
err := relayPaLM(textRequest, c)
|
||||
return err
|
||||
@@ -182,6 +189,57 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
|
||||
requestBody = bytes.NewBuffer(bodyBytes)
|
||||
}
|
||||
|
||||
if channelType == common.ChannelTypeChatGPTWeb {
|
||||
// Get system message from Message json, Role == "system"
|
||||
var reqBody ChatRequest
|
||||
var systemMessage Message
|
||||
|
||||
// Parse requestBody into systemMessage
|
||||
err := json.NewDecoder(requestBody).Decode(&reqBody)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "decode_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for _, message := range reqBody.Messages {
|
||||
if message.Role == "system" {
|
||||
systemMessage = message
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var prompt string
|
||||
|
||||
// Get all the Message, Roles from request.Messages, and format it into string by
|
||||
// ||> role: content
|
||||
for _, message := range reqBody.Messages {
|
||||
// Exclude system message
|
||||
if message.Role == "system" {
|
||||
continue
|
||||
}
|
||||
prompt += "||> " + message.Role + ": " + message.Content + "\n"
|
||||
}
|
||||
|
||||
// Construct json data without adding escape character
|
||||
map1 := map[string]string{
|
||||
"prompt": prompt,
|
||||
"systemMessage": systemMessage.Content,
|
||||
"temperature": strconv.FormatFloat(reqBody.Temperature, 'f', 2, 64),
|
||||
"top_p": strconv.FormatFloat(reqBody.TopP, 'f', 2, 64),
|
||||
}
|
||||
|
||||
// Convert map to json string
|
||||
jsonData, err := json.Marshal(map1)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_json_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Convert json string to io.Reader
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
@@ -235,7 +293,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
var textResponse TextResponse
|
||||
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") || strings.HasPrefix(resp.Header.Get("Content-Type"), "application/octet-stream")
|
||||
var streamResponseText string
|
||||
|
||||
defer func() {
|
||||
@@ -286,82 +344,129 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}()
|
||||
|
||||
if isStream {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
if i := strings.Index(string(data), "\n\n"); i >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // must be something wrong!
|
||||
common.SysError("invalid stream response: " + data)
|
||||
continue
|
||||
}
|
||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||
// Remove event: event in the front or back
|
||||
data = strings.TrimPrefix(data, "event: event")
|
||||
data = strings.TrimSuffix(data, "event: event")
|
||||
// Remove everything, only keep `data: {...}` <--- this is the json
|
||||
// Find the start and end indices of `data: {...}` substring
|
||||
startIndex := strings.Index(data, "data:")
|
||||
endIndex := strings.LastIndex(data, "}")
|
||||
|
||||
// If both indices are found and end index is greater than start index
|
||||
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
|
||||
// Extract the `data: {...}` substring
|
||||
data = data[startIndex : endIndex+1]
|
||||
if channelType == common.ChannelTypeChatGPTWeb {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
var chatResponse ChatGptWebChatResponse
|
||||
err = json.Unmarshal(scanner.Bytes(), &chatResponse)
|
||||
|
||||
if err != nil {
|
||||
log.Println("error unmarshal chat response: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// Trim whitespace and newlines from the modified data string
|
||||
data = strings.TrimSpace(data)
|
||||
}
|
||||
if !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
var streamResponse ChatCompletionsStreamResponse
|
||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
// if response role is assistant and contains delta, append the content to streamResponseText
|
||||
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
|
||||
for _, choice := range chatResponse.Detail.Choices {
|
||||
streamResponseText += choice.Delta.Content
|
||||
}
|
||||
case RelayModeCompletions:
|
||||
var streamResponse CompletionsStreamResponse
|
||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
streamResponseText += choice.Text
|
||||
|
||||
returnObj := map[string]interface{}{
|
||||
"id": chatResponse.ID,
|
||||
"object": chatResponse.Detail.Object,
|
||||
"created": chatResponse.Detail.Created,
|
||||
"model": chatResponse.Detail.Model,
|
||||
"choices": []map[string]interface{}{
|
||||
// set finish_reason to null in json
|
||||
{
|
||||
"finish_reason": nil,
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"content": choice.Delta.Content,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(returnObj)
|
||||
|
||||
dataChan <- "data: " + string(jsonData)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
stopChan <- true
|
||||
}()
|
||||
} else {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
if i := strings.Index(string(data), "\n\n"); i >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // must be something wrong!
|
||||
common.SysError("invalid stream response: " + data)
|
||||
continue
|
||||
}
|
||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||
// Remove event: event in the front or back
|
||||
data = strings.TrimPrefix(data, "event: event")
|
||||
data = strings.TrimSuffix(data, "event: event")
|
||||
// Remove everything, only keep `data: {...}` <--- this is the json
|
||||
// Find the start and end indices of `data: {...}` substring
|
||||
startIndex := strings.Index(data, "data:")
|
||||
endIndex := strings.LastIndex(data, "}")
|
||||
|
||||
// If both indices are found and end index is greater than start index
|
||||
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
|
||||
// Extract the `data: {...}` substring
|
||||
data = data[startIndex : endIndex+1]
|
||||
}
|
||||
|
||||
// Trim whitespace and newlines from the modified data string
|
||||
data = strings.TrimSpace(data)
|
||||
}
|
||||
if !strings.HasPrefix(data, "data:") {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
var streamResponse ChatCompletionsStreamResponse
|
||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
streamResponseText += choice.Delta.Content
|
||||
}
|
||||
case RelayModeCompletions:
|
||||
var streamResponse CompletionsStreamResponse
|
||||
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
streamResponseText += choice.Text
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
@@ -373,6 +478,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
log.Print(data)
|
||||
c.Render(-1, common.CustomEvent{Data: data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
|
||||
@@ -46,6 +46,9 @@ type ChatRequest struct {
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
// -1.0 to 1.0
|
||||
Temperature float64 `json:"temperature"`
|
||||
TopP float64 `json:"top_p"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
@@ -102,6 +105,32 @@ type CompletionsStreamResponse struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatGptWebDetail struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatGptWebChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatGptWebChoice struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
} `json:"delta"`
|
||||
Index int `json:"index"`
|
||||
Finish_Reason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type ChatGptWebChatResponse struct {
|
||||
Role string `json:"role"`
|
||||
ID string `json:"id"`
|
||||
ParentMessageID string `json:"parentMessageId"`
|
||||
Text string `json:"text"`
|
||||
Delta string `json:"delta"`
|
||||
Detail *ChatGptWebDetail `json:"detail"`
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||
|
||||
Reference in New Issue
Block a user