diff --git a/dto/dalle.go b/dto/dalle.go index d366051..d0bba65 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -12,9 +12,11 @@ type ImageRequest struct { } type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - } + Data []ImageData `json:"data"` + Created int64 `json:"created"` +} +type ImageData struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 88990d1..98728a0 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" ) @@ -15,23 +16,18 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl) - if info.RelayMode == constant.RelayModeEmbeddings { + var fullRequestURL string + switch info.RelayMode { + case constant.RelayModeEmbeddings: fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) + case constant.RelayModeImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + default: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) } return fullRequestURL, nil } @@ -57,13 +53,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) return baiduEmbeddingRequest, nil default: - baiduRequest := requestOpenAI2Ali(*request) - return baiduRequest, nil + aliReq := requestOpenAI2Ali(*request) + return aliReq, nil } } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + aliRequest := oaiImage2Ali(request) + return aliRequest, nil +} + func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { @@ -71,14 +77,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = aliStreamHandler(c, resp) - } else { - switch info.RelayMode { - case constant.RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - default: - err, usage = aliHandler(c, resp) + switch info.RelayMode { + case constant.RelayModeImagesGenerations: + err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + if info.IsStream { + err, usage = openai.OpenaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } } return diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go index fd1f07a..f51286a 100644 --- a/relay/channel/ali/dto.go +++ b/relay/channel/ali/dto.go @@ -60,13 +60,40 @@ type AliUsage struct { TotalTokens int `json:"total_tokens"` } -type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` +type TaskResult struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` } -type AliChatResponse struct { +type AliOutput struct { + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + Message string `json:"message,omitempty"` + Code string `json:"code,omitempty"` + Results []TaskResult `json:"results,omitempty"` +} + +type AliResponse struct { Output AliOutput `json:"output"` Usage AliUsage `json:"usage"` AliError } + +type AliImageRequest struct { + Model string `json:"model"` + Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go new file mode 100644 index 0000000..160fabf --- /dev/null +++ b/relay/channel/ali/image.go @@ -0,0 +1,177 @@ +package ali + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" + "time" +) + +func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { + var imageRequest AliImageRequest + imageRequest.Input.Prompt = request.Prompt + imageRequest.Model = request.Model + imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) + imageRequest.Parameters.N = request.N + imageRequest.ResponseFormat = request.ResponseFormat + + return &imageRequest +} + +func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) { + url := fmt.Sprintf("/api/v1/tasks/%s", taskID) + + var aliResponse AliResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + common.SysError("updateTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response AliResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + common.SysError("updateTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) { + waitSeconds := 3 + step := 0 + maxStep := 20 + + var taskResponse AliResponse + var responseBody []byte + + for { + step++ + rsp, err, body := updateTask(info, taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} + +func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse { + imageResponse := dto.ImageResponse{ + Created: info.StartTime.Unix(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + _, b64, err := service.GetImageFromUrl(data.Url) + if err != nil { + common.LogError(c, "get_image_data_failed: "+err.Error()) + continue + } + b64Json = b64 + } else { + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, dto.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageResponse +} + +func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse AliResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) + return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} diff --git a/relay/channel/ali/relay-ali.go b/relay/channel/ali/text.go similarity index 82% rename from relay/channel/ali/relay-ali.go rename to relay/channel/ali/text.go index 4280b1c..aec857f 100644 --- a/relay/channel/ali/relay-ali.go +++ b/relay/channel/ali/text.go @@ -16,34 +16,13 @@ import ( const EnableSearchModelSuffix = "-internet" -func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest { - messages := make([]AliMessage, 0, len(request.Messages)) - //prompt := "" - for i := 0; i < len(request.Messages); i++ { - message := request.Messages[i] - messages = append(messages, AliMessage{ - Content: message.StringContent(), - Role: strings.ToLower(message.Role), - }) - } - enableSearch := false - aliModel := request.Model - if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { - enableSearch = true - aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) - } - return &AliChatRequest{ - Model: request.Model, - Input: AliInput{ - //Prompt: prompt, - Messages: messages, - }, - Parameters: AliParameters{ - IncrementalOutput: request.Stream, - Seed: uint64(request.Seed), - EnableSearch: enableSearch, - }, +func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { + if request.TopP >= 1 { + request.TopP = 0.999 + } else if request.TopP <= 0 { + request.TopP = 0.001 } + return &request } func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { @@ -110,7 +89,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe return &openAIEmbeddingResponse } -func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { +func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse { content, _ := json.Marshal(response.Output.Text) choice := dto.OpenAITextResponseChoice{ Index: 0, @@ -134,7 +113,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { return &fullTextResponse } -func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse { +func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString(aliResponse.Output.Text) if aliResponse.Output.FinishReason != "null" { @@ -154,18 +133,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletions func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var usage dto.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) dataChan := make(chan string) stopChan := make(chan bool) go func() { @@ -187,7 +155,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var aliResponse AliChatResponse + var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -221,7 +189,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith } func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var aliResponse AliChatResponse + var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil