diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36798711..3034a547 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,13 +1,13 @@ name: CI # This setup assumes that you run the unit tests with code coverage in the same -# workflow that will also print the coverage report as comment to the pull request. +# workflow that will also print the coverage report as comment to the pull request. # Therefore, you need to trigger this workflow when a pull request is (re)opened or # when new code is pushed to the branch of the pull request. In addition, you also -# need to trigger this workflow when new code is pushed to the main branch because +# need to trigger this workflow when new code is pushed to the main branch because # we need to upload the code coverage results as artifact for the main branch as # well since it will be the baseline code coverage. -# +# # We do not want to trigger the workflow for pushes to *any* branch because this # would trigger our jobs twice on pull requests (once from "push" event and once # from "pull_request->synchronize") @@ -31,7 +31,7 @@ jobs: with: go-version: ^1.22 - # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a + # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") # in the next step as well as the next job. - name: Test diff --git a/.gitignore b/.gitignore index 4e431e65..0cedb4b4 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ logs data /web/node_modules cmd.md -.env \ No newline at end of file +.env +/one-api diff --git a/README.md b/README.md index fb137c23..a9db89a8 100644 --- a/README.md +++ b/README.md @@ -115,8 +115,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 21. 支持 Cloudflare Turnstile 用户校验。 22. 
支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 - + 支持使用飞书进行授权登录。 - + [GitHub 开放授权](https://github.com/settings/applications/new)。 + + 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)([这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login))。 + + 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 diff --git a/common/render/render.go b/common/render/render.go index 646b3777..e565c0b7 100644 --- a/common/render/render.go +++ b/common/render/render.go @@ -3,9 +3,10 @@ package render import ( "encoding/json" "fmt" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" - "strings" ) func StringData(c *gin.Context, str string) { diff --git a/controller/auth/lark.go b/controller/auth/lark.go index eb06dde9..39088b3c 100644 --- a/controller/auth/lark.go +++ b/controller/auth/lark.go @@ -40,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { if err != nil { return nil, err } - req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) if err != nil { return nil, err } diff --git a/one-api b/one-api deleted file mode 100755 index 4c9190bb..00000000 Binary files a/one-api and /dev/null differ diff --git a/relay/adaptor.go b/relay/adaptor.go index 711e63bd..03e83903 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -16,6 +16,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/proxy" + 
"github.com/songquanpeng/one-api/relay/adaptor/replicate" "github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/vertexai" "github.com/songquanpeng/one-api/relay/adaptor/xunfei" @@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &vertexai.Adaptor{} case apitype.Proxy: return &proxy.Adaptor{} + case apitype.Replicate: + return &replicate.Adaptor{} } return nil } diff --git a/relay/adaptor/openai/constants.go b/relay/adaptor/openai/constants.go index aacdba1a..be4804c2 100644 --- a/relay/adaptor/openai/constants.go +++ b/relay/adaptor/openai/constants.go @@ -20,4 +20,7 @@ var ModelList = []string{ "dall-e-2", "dall-e-3", "whisper-1", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", + "o1", "o1-2024-12-17", + "o1-preview", "o1-preview-2024-09-12", + "o1-mini", "o1-mini-2024-09-12", } diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go index 7d73303b..47c2a882 100644 --- a/relay/adaptor/openai/helper.go +++ b/relay/adaptor/openai/helper.go @@ -2,15 +2,16 @@ package openai import ( "fmt" + "strings" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/model" - "strings" ) -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { +func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { usage := &model.Usage{} usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) + usage.CompletionTokens = CountTokenText(responseText, modelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage } diff --git a/relay/adaptor/openai/util.go b/relay/adaptor/openai/util.go index ba0cab7d..83beadba 100644 --- a/relay/adaptor/openai/util.go +++ b/relay/adaptor/openai/util.go @@ -1,8 +1,16 @@ package openai -import "github.com/songquanpeng/one-api/relay/model" +import ( + "context" + "fmt" + + 
"github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" +) func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { + logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) + Error := model.Error{ Message: err.Error(), Type: "one_api_error", diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go new file mode 100644 index 00000000..a60a7de3 --- /dev/null +++ b/relay/adaptor/replicate/adaptor.go @@ -0,0 +1,136 @@ +package replicate + +import ( + "fmt" + "io" + "net/http" + "slices" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. 
+func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return DrawImageRequest{ + Input: ImageInput{ + Steps: 25, + Prompt: request.Prompt, + Guidance: 3, + Seed: int(time.Now().UnixNano()), + SafetyTolerance: 5, + NImages: 1, // replicate will always return 1 image + Width: 1440, + Height: 1440, + AspectRatio: "1:1", + }, + }, nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if !request.Stream { + // TODO: support non-stream mode + return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") + } + + // Build the prompt from OpenAI messages + var promptBuilder strings.Builder + for _, message := range request.Messages { + switch msgCnt := message.Content.(type) { + case string: + promptBuilder.WriteString(message.Role) + promptBuilder.WriteString(": ") + promptBuilder.WriteString(msgCnt) + promptBuilder.WriteString("\n") + default: + } + } + + replicateRequest := ReplicateChatRequest{ + Input: ChatInput{ + Prompt: promptBuilder.String(), + MaxTokens: request.MaxTokens, + Temperature: 1.0, + TopP: 1.0, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + }, + } + + // Map optional fields + if request.Temperature != nil { + replicateRequest.Input.Temperature = *request.Temperature + } + if request.TopP != nil { + replicateRequest.Input.TopP = *request.TopP + } + if request.PresencePenalty != nil { + replicateRequest.Input.PresencePenalty = *request.PresencePenalty + } + if request.FrequencyPenalty != nil { + replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty + } + if request.MaxTokens > 0 { + replicateRequest.Input.MaxTokens = request.MaxTokens + } else if request.MaxTokens == 0 { + replicateRequest.Input.MaxTokens = 500 + } + + return replicateRequest, nil +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + if 
!slices.Contains(ModelList, meta.OriginModelName) { + return "", errors.Errorf("model %s not supported", meta.OriginModelName) + } + + return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + logger.Info(c, "send request to replicate") + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case relaymode.ImagesGenerations: + err, usage = ImageHandler(c, resp) + case relaymode.ChatCompletions: + err, usage = ChatHandler(c, resp) + default: + err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) + } + + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "replicate" +} diff --git a/relay/adaptor/replicate/chat.go b/relay/adaptor/replicate/chat.go new file mode 100644 index 00000000..4051f85c --- /dev/null +++ b/relay/adaptor/replicate/chat.go @@ -0,0 +1,191 @@ +package replicate + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +func ChatHandler(c *gin.Context, resp *http.Response) ( + srvErr *model.ErrorWithStatusCode, usage *model.Usage) { + if 
resp.StatusCode != http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ChatResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ChatResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + if taskData.URLs.Stream == "" { + return errors.New("stream url is empty") + } + + // request stream url + responseText, err := chatStreamHandler(c, taskData.URLs.Stream) + if err != nil { + return errors.Wrap(err, "chat stream handler") + } + + ctxMeta := 
meta.GetByContext(c) + usage = openai.ResponseText2Usage(responseText, + ctxMeta.ActualModelName, ctxMeta.PromptTokens) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, usage +} + +const ( + eventPrefix = "event: " + dataPrefix = "data: " + done = "[DONE]" +) + +func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { + // request stream endpoint + streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) + if err != nil { + return "", errors.Wrap(err, "new request to stream") + } + + streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + streamReq.Header.Set("Accept", "text/event-stream") + streamReq.Header.Set("Cache-Control", "no-store") + + resp, err := http.DefaultClient.Do(streamReq) + if err != nil { + return "", errors.Wrap(err, "do request to stream") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(resp.Body) + return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + doneRendered := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // Handle comments starting with ':' + if strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE fields + if strings.HasPrefix(line, eventPrefix) { + event := strings.TrimSpace(line[len(eventPrefix):]) + var data string + // Read the following lines to get data and id + for scanner.Scan() { + nextLine := scanner.Text() + if nextLine == "" { + break + } + if strings.HasPrefix(nextLine, dataPrefix) { + data = nextLine[len(dataPrefix):] + } else if strings.HasPrefix(nextLine, "id:") { + // id = 
strings.TrimSpace(nextLine[len("id:"):]) + } + } + + if event == "output" { + render.StringData(c, data) + responseText += data + } else if event == "done" { + render.Done(c) + doneRendered = true + break + } + } + } + + if err := scanner.Err(); err != nil { + return "", errors.Wrap(err, "scan stream") + } + + if !doneRendered { + render.Done(c) + } + + return responseText, nil +} diff --git a/relay/adaptor/replicate/constant.go b/relay/adaptor/replicate/constant.go new file mode 100644 index 00000000..989142c9 --- /dev/null +++ b/relay/adaptor/replicate/constant.go @@ -0,0 +1,58 @@ +package replicate + +// ModelList is a list of models that can be used with Replicate. +// +// https://replicate.com/pricing +var ModelList = []string{ + // ------------------------------------- + // image model + // ------------------------------------- + "black-forest-labs/flux-1.1-pro", + "black-forest-labs/flux-1.1-pro-ultra", + "black-forest-labs/flux-canny-dev", + "black-forest-labs/flux-canny-pro", + "black-forest-labs/flux-depth-dev", + "black-forest-labs/flux-depth-pro", + "black-forest-labs/flux-dev", + "black-forest-labs/flux-dev-lora", + "black-forest-labs/flux-fill-dev", + "black-forest-labs/flux-fill-pro", + "black-forest-labs/flux-pro", + "black-forest-labs/flux-redux-dev", + "black-forest-labs/flux-redux-schnell", + "black-forest-labs/flux-schnell", + "black-forest-labs/flux-schnell-lora", + "ideogram-ai/ideogram-v2", + "ideogram-ai/ideogram-v2-turbo", + "recraft-ai/recraft-v3", + "recraft-ai/recraft-v3-svg", + "stability-ai/stable-diffusion-3", + "stability-ai/stable-diffusion-3.5-large", + "stability-ai/stable-diffusion-3.5-large-turbo", + "stability-ai/stable-diffusion-3.5-medium", + // ------------------------------------- + // language model + // ------------------------------------- + "ibm-granite/granite-20b-code-instruct-8k", + "ibm-granite/granite-3.0-2b-instruct", + "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k", + 
"meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + "meta/meta-llama-3.1-405b-instruct", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct", + "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1", + "mistralai/mixtral-8x7b-instruct-v0.1", + // ------------------------------------- + // video model + // ------------------------------------- + // "minimax/video-01", // TODO: implement the adaptor +} diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go new file mode 100644 index 00000000..3687249a --- /dev/null +++ b/relay/adaptor/replicate/image.go @@ -0,0 +1,222 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "image" + "image/png" + "io" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "golang.org/x/image/webp" + "golang.org/x/sync/errgroup" +) + +// ImagesEditsHandler just copy response body to client +// +// https://replicate.com/black-forest-labs/flux-fill-pro +// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +// c.Writer.WriteHeader(resp.StatusCode) +// for k, v := range resp.Header { +// c.Writer.Header().Set(k, v[0]) +// } + +// if _, err := io.Copy(c.Writer, resp.Body); err != nil { +// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil +// } +// defer resp.Body.Close() + +// return nil, nil +// } + +var errNextLoop = errors.New("next_loop") + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + if resp.StatusCode != 
http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ImageResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ImageResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed: %s", taskData.Status) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + output, err := taskData.GetOutput() + if err != nil { + return errors.Wrap(err, "get output") + } + if len(output) == 0 { + return errors.New("response output is empty") + } + + var mu sync.Mutex + var pool errgroup.Group + respBody := &openai.ImageResponse{ + Created: taskData.CompletedAt.Unix(), + Data: 
[]openai.ImageData{}, + } + + for _, imgOut := range output { + imgOut := imgOut + pool.Go(func() error { + // download image + downloadReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, imgOut, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + imgResp, err := http.DefaultClient.Do(downloadReq) + if err != nil { + return errors.Wrap(err, "download image") + } + defer imgResp.Body.Close() + + if imgResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(imgResp.Body) + return errors.Errorf("bad status code [%d]%s", + imgResp.StatusCode, string(payload)) + } + + imgData, err := io.ReadAll(imgResp.Body) + if err != nil { + return errors.Wrap(err, "read image") + } + + imgData, err = ConvertImageToPNG(imgData) + if err != nil { + return errors.Wrap(err, "convert image") + } + + mu.Lock() + respBody.Data = append(respBody.Data, openai.ImageData{ + B64Json: fmt.Sprintf("data:image/png;base64,%s", + base64.StdEncoding.EncodeToString(imgData)), + }) + mu.Unlock() + + return nil + }) + } + + if err := pool.Wait(); err != nil { + if len(respBody.Data) == 0 { + return errors.WithStack(err) + } + + logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) + } + + c.JSON(http.StatusOK, respBody) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, nil +} + +// ConvertImageToPNG converts a WebP image to PNG format +func ConvertImageToPNG(webpData []byte) ([]byte, error) { + // bypass if it's already a PNG image + if bytes.HasPrefix(webpData, []byte("\x89PNG")) { + return webpData, nil + } + + // check if is jpeg, convert to png + if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) { + img, _, err := image.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode jpeg") + } + + var pngBuffer bytes.Buffer + if err := 
png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil + } + + // Decode the WebP image + img, err := webp.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode webp") + } + + // Encode the image as PNG + var pngBuffer bytes.Buffer + if err := png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil +} diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go new file mode 100644 index 00000000..dba277eb --- /dev/null +++ b/relay/adaptor/replicate/model.go @@ -0,0 +1,159 @@ +package replicate + +import ( + "time" + + "github.com/pkg/errors" +) + +// DrawImageRequest draw image by fluxpro +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type DrawImageRequest struct { + Input ImageInput `json:"input"` +} + +// ImageInput is input of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema +type ImageInput struct { + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + ImagePrompt string `json:"image_prompt"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + Interval int `json:"interval" binding:"required,min=1,max=4"` + AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + Seed int `json:"seed"` + NImages int `json:"n_images" binding:"required,min=1,max=8"` + Width int `json:"width" binding:"required,min=256,max=1440"` + Height int `json:"height" binding:"required,min=256,max=1440"` +} + +// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type 
InpaintingImageByFlusReplicateRequest struct { + Input FluxInpaintingInput `json:"input"` +} + +// FluxInpaintingInput is input of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type FluxInpaintingInput struct { + Mask string `json:"mask" binding:"required"` + Image string `json:"image" binding:"required"` + Seed int `json:"seed"` + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + OutputFormat string `json:"output_format"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + PromptUnsampling bool `json:"prompt_unsampling"` +} + +// ImageResponse is response of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type ImageResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input DrawImageRequest `json:"input"` + Logs string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output any `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs FluxURLs `json:"urls"` + Version string `json:"version"` +} + +func (r *ImageResponse) GetOutput() ([]string, error) { + switch v := r.Output.(type) { + case string: + return []string{v}, nil + case []string: + return v, nil + case nil: + return nil, nil + case []interface{}: + // convert []interface{} to []string + ret := make([]string, len(v)) + for idx, vv := range v { + if vvv, ok := vv.(string); ok { + ret[idx] = vvv + } else { + return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv) + } + } + + return ret, nil + default: + return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) 
+ } +} + +// FluxMetrics is metrics of ImageResponse +type FluxMetrics struct { + ImageCount int `json:"image_count"` + PredictTime float64 `json:"predict_time"` + TotalTime float64 `json:"total_time"` +} + +// FluxURLs is urls of ImageResponse +type FluxURLs struct { + Get string `json:"get"` + Cancel string `json:"cancel"` +} + +type ReplicateChatRequest struct { + Input ChatInput `json:"input" form:"input" binding:"required"` +} + +// ChatInput is input of ChatByReplicateRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema +type ChatInput struct { + TopK int `json:"top_k"` + TopP float64 `json:"top_p"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` + MinTokens int `json:"min_tokens"` + Temperature float64 `json:"temperature"` + SystemPrompt string `json:"system_prompt"` + StopSequences string `json:"stop_sequences"` + PromptTemplate string `json:"prompt_template"` + PresencePenalty float64 `json:"presence_penalty"` + FrequencyPenalty float64 `json:"frequency_penalty"` +} + +// ChatResponse is response of ChatByReplicateRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json +type ChatResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input ChatInput `json:"input"` + Logs string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output []string `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs ChatResponseUrl `json:"urls"` + Version string `json:"version"` +} + +// ChatResponseUrl is task urls of ChatResponse +type ChatResponseUrl struct { + Stream string `json:"stream"` + Get string `json:"get"` + Cancel string `json:"cancel"` +} diff --git a/relay/apitype/define.go b/relay/apitype/define.go index cf7b6a0d..0c6a5ff1 100644 
--- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -19,6 +19,7 @@ const ( DeepL VertexAI Proxy + Replicate Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index f7d862c7..613d2b31 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -48,8 +48,14 @@ var ModelRatio = map[string]float64{ "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens - "davinci-002": 1, // $0.002 / 1K tokens - "babbage-002": 0.2, // $0.0004 / 1K tokens + "o1": 7.5, // $15.00 / 1M input tokens + "o1-2024-12-17": 7.5, + "o1-preview": 7.5, // $15.00 / 1M input tokens + "o1-preview-2024-09-12": 7.5, + "o1-mini": 1.5, // $3.00 / 1M input tokens + "o1-mini-2024-09-12": 1.5, + "davinci-002": 1, // $0.002 / 1K tokens + "babbage-002": 0.2, // $0.0004 / 1K tokens "text-ada-001": 0.2, "text-babbage-001": 0.25, "text-curie-001": 1, @@ -214,6 +220,50 @@ var ModelRatio = map[string]float64{ "deepl-ja": 25.0 / 1000 * USD, // https://console.x.ai/ "grok-beta": 5.0 / 1000 * USD, + // replicate charges based on the number of generated images + // https://replicate.com/pricing + "black-forest-labs/flux-1.1-pro": 0.04 * USD, + "black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD, + "black-forest-labs/flux-canny-dev": 0.025 * USD, + "black-forest-labs/flux-canny-pro": 0.05 * USD, + "black-forest-labs/flux-depth-dev": 0.025 * USD, + "black-forest-labs/flux-depth-pro": 0.05 * USD, + "black-forest-labs/flux-dev": 0.025 * USD, + "black-forest-labs/flux-dev-lora": 0.032 * USD, + "black-forest-labs/flux-fill-dev": 0.04 * USD, + "black-forest-labs/flux-fill-pro": 0.05 * USD, + "black-forest-labs/flux-pro": 0.055 * USD, + "black-forest-labs/flux-redux-dev": 0.025 * USD, + "black-forest-labs/flux-redux-schnell": 0.003 * USD, + "black-forest-labs/flux-schnell": 0.003 * USD, + 
"black-forest-labs/flux-schnell-lora": 0.02 * USD, + "ideogram-ai/ideogram-v2": 0.08 * USD, + "ideogram-ai/ideogram-v2-turbo": 0.05 * USD, + "recraft-ai/recraft-v3": 0.04 * USD, + "recraft-ai/recraft-v3-svg": 0.08 * USD, + "stability-ai/stable-diffusion-3": 0.035 * USD, + "stability-ai/stable-diffusion-3.5-large": 0.065 * USD, + "stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, + "stability-ai/stable-diffusion-3.5-medium": 0.035 * USD, + // replicate chat models + "ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD, + "ibm-granite/granite-3.0-2b-instruct": 0.030 * USD, + "ibm-granite/granite-3.0-8b-instruct": 0.050 * USD, + "ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD, + "meta/llama-2-13b": 0.100 * USD, + "meta/llama-2-13b-chat": 0.100 * USD, + "meta/llama-2-70b": 0.650 * USD, + "meta/llama-2-70b-chat": 0.650 * USD, + "meta/llama-2-7b": 0.050 * USD, + "meta/llama-2-7b-chat": 0.050 * USD, + "meta/meta-llama-3.1-405b-instruct": 9.500 * USD, + "meta/meta-llama-3-70b": 0.650 * USD, + "meta/meta-llama-3-70b-instruct": 0.650 * USD, + "meta/meta-llama-3-8b": 0.050 * USD, + "meta/meta-llama-3-8b-instruct": 0.050 * USD, + "mistralai/mistral-7b-instruct-v0.2": 0.050 * USD, + "mistralai/mistral-7b-v0.1": 0.050 * USD, + "mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD, } var CompletionRatio = map[string]float64{ @@ -347,6 +397,10 @@ func GetCompletionRatio(name string, channelType int) float64 { } return 2 } + // including o1, o1-preview, o1-mini + if strings.HasPrefix(name, "o1") { + return 4 + } if name == "chatgpt-4o-latest" { return 3 } @@ -365,6 +419,7 @@ func GetCompletionRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "deepseek-") { return 2 } + switch name { case "llama2-70b-4096": return 0.8 / 0.64 @@ -380,6 +435,35 @@ func GetCompletionRatio(name string, channelType int) float64 { return 5 case "grok-beta": return 3 + // Replicate Models + // https://replicate.com/pricing + case 
"ibm-granite/granite-20b-code-instruct-8k": + return 5 + case "ibm-granite/granite-3.0-2b-instruct": + return 8.333333333333334 + case "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k": + return 5 + case "meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct": + return 5 + case "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct": + return 2.750 / 0.650 // ≈4.230769 + case "meta/meta-llama-3.1-405b-instruct": + return 1 + case "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1": + return 5 + case "mistralai/mixtral-8x7b-instruct-v0.1": + return 1.000 / 0.300 // ≈3.333333 } + return 1 } diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index 98316959..f54d0e30 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -47,5 +47,6 @@ const ( Proxy SiliconFlow XAI + Replicate Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index fae3357f..8839b30a 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -37,6 +37,8 @@ func ToAPIType(channelType int) int { apiType = apitype.DeepL case VertextAI: apiType = apitype.VertexAI + case Replicate: + apiType = apitype.Replicate case Proxy: apiType = apitype.Proxy } diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index b8bd61f8..8e271f4e 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{ "", // 43 "https://api.siliconflow.cn", // 44 "https://api.x.ai", // 45 + "https://api.replicate.com/v1/models/", // 46 } func init() { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 567dee7c..5f5fc90c 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -147,14 +147,20 @@ func isErrorHappened(meta *meta.Meta, resp 
*http.Response) bool { } return true } - if resp.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusOK && + // replicate return 201 to create a task + resp.StatusCode != http.StatusCreated { return true } if meta.ChannelType == channeltype.DeepL { // skip stream check for deepl return false } - if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + + if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") && + // Even if stream mode is enabled, replicate will first return a task info in JSON format, + // requiring the client to request the stream endpoint in the task info + meta.ChannelType != channeltype.Replicate { return true } return false diff --git a/relay/controller/image.go b/relay/controller/image.go index 1e06e858..1b69d97d 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -22,7 +22,7 @@ import ( relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { +func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { @@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 { return 1 } -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { // check prompt length if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) @@ -150,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } adaptor.Init(meta) + // these adaptors need to convert the request switch meta.ChannelType { - case channeltype.Ali: - fallthrough - case 
channeltype.Baidu: - fallthrough - case channeltype.Zhipu: + case channeltype.Zhipu, + channeltype.Ali, + channeltype.Replicate, + channeltype.Baidu: finalRequest, err := adaptor.ConvertImageRequest(imageRequest) if err != nil { return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) @@ -172,7 +172,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + var quota int64 + switch meta.ChannelType { + case channeltype.Replicate: + // replicate always return 1 image + quota = int64(ratio * imageCostRatio * 1000) + default: + quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + } if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -186,7 +193,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } defer func(ctx context.Context) { - if resp != nil && resp.StatusCode != http.StatusOK { + if resp != nil && + resp.StatusCode != http.StatusCreated && // replicate returns 201 + resp.StatusCode != http.StatusOK { return } diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js index a7e984ec..e7b25399 100644 --- a/web/air/src/constants/channel.constants.js +++ b/web/air/src/constants/channel.constants.js @@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [ { key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, { key: 45, text: 'xAI', value: 45, color: 'blue' }, + { key: 46, text: 'Replicate', value: 46, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git 
a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 35398875..375adcd9 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = { value: 45, color: 'primary' }, + 46: { + key: 46, + text: 'Replicate', + value: 46, + color: 'primary' + }, 41: { key: 41, text: 'Novita', diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index f9c2896c..bd85f8bf 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -95,7 +95,7 @@ export async function onLarkOAuthClicked(lark_client_id) { const state = await getOAuthState(); if (!state) return; let redirect_uri = `${window.location.origin}/oauth/lark`; - window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); + window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`); } export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { diff --git a/web/build.sh b/web/build.sh old mode 100644 new mode 100755 diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index 5b25577d..61425508 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [ { key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, { key: 45, text: 'xAI', value: 45, color: 'blue' }, + { key: 46, text: 'Replicate', value: 46, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },