diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index dd688493..0ddd31a8 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,26 +1,27 @@ --- -name: 报告问题 -about: 使用简练详细的语言描述你遇到的问题 -title: '' +name: Bug Report +about: Use concise and detailed language to describe the issue you encountered +title: "" labels: bug -assignees: '' - +assignees: "" --- -**例行检查** +## Routine Check -[//]: # (方框内删除已有的空格,填 x 号) -+ [ ] 我已确认目前没有类似 issue -+ [ ] 我已确认我已升级到最新版本 -+ [ ] 我已完整查看过项目 README,尤其是常见问题部分 -+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 -+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** +[//]: # "Remove space in brackets and fill with x" -**问题描述** +- [ ] I have confirmed there are no similar issues +- [ ] I have confirmed I am using the latest version +- [ ] I have thoroughly read the project README, especially the FAQ section +- [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback +- [ ] I understand and agree to the above, and understand that maintainers have limited time - **issues not following guidelines may be ignored or closed** -**复现步骤** +## Issue Description -**预期结果** +## Steps to Reproduce -**相关截图** -如果没有的话,请删除此节。 \ No newline at end of file +## Expected Behavior + +## Screenshots + +Delete this section if not applicable. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml deleted file mode 100644 index 83a0f3f4..00000000 --- a/.github/ISSUE_TEMPLATE/config.yml +++ /dev/null @@ -1,8 +0,0 @@ -blank_issues_enabled: false -contact_links: - - name: 项目群聊 - url: https://openai.justsong.cn/ - about: QQ 群:828520184,自动审核,备注 One API - - name: 赞赏支持 - url: https://iamazing.cn/page/reward - about: 请作者喝杯咖啡,以激励作者持续开发 diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 049d89c8..c688f9d1 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,21 +1,21 @@ --- -name: 功能请求 -about: 使用简练详细的语言描述希望加入的新功能 -title: '' +name: Feature Request +about: Use concise and detailed language to describe the new feature you'd like to see +title: "" labels: enhancement -assignees: '' - +assignees: "" --- -**例行检查** +## Routine Check -[//]: # (方框内删除已有的空格,填 x 号) -+ [ ] 我已确认目前没有类似 issue -+ [ ] 我已确认我已升级到最新版本 -+ [ ] 我已完整查看过项目 README,已确定现有版本无法满足需求 -+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 -+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** +[//]: # "Remove space in brackets and fill with x" -**功能描述** +- [ ] I have confirmed there are no similar issues +- [ ] I have confirmed I am using the latest version +- [ ] I have thoroughly read the project README and confirmed existing features cannot meet my needs +- [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback +- [ ] I understand and agree to the above, and understand that maintainers have limited time - **issues not following guidelines may be ignored or closed** -**应用场景** +## Feature Description + +## Use Cases diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36798711..035abb6d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,47 +1,142 @@ -name: CI +name: ci -# This setup assumes that you run the unit tests with code coverage in the same -# workflow that will also print the coverage report as comment to the pull request. 
-# Therefore, you need to trigger this workflow when a pull request is (re)opened or -# when new code is pushed to the branch of the pull request. In addition, you also -# need to trigger this workflow when new code is pushed to the main branch because -# we need to upload the code coverage results as artifact for the main branch as -# well since it will be the baseline code coverage. -# -# We do not want to trigger the workflow for pushes to *any* branch because this -# would trigger our jobs twice on pull requests (once from "push" event and once -# from "pull_request->synchronize") on: - pull_request: - types: [opened, reopened, synchronize] push: branches: - - 'main' + - "master" + - "main" + - "test/ci" jobs: - unit_tests: - name: "Unit tests" + build_latest: runs-on: ubuntu-latest steps: - - name: Checkout repository + - name: Checkout uses: actions/checkout@v4 - - name: Setup Go - uses: actions/setup-go@v4 - with: - go-version: ^1.22 + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 - # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a - # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") - # in the next step as well as the next job. - - name: Test - run: go test -cover -coverprofile=coverage.txt ./... - - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - commit_lint: + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push latest + uses: docker/build-push-action@v5 + with: + context: . + push: true + tags: ppcelery/one-api:latest + cache-from: type=gha + # cache-to: type=gha,mode=max + + build_hash: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: wagoid/commitlint-github-action@v6 + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Add SHORT_SHA env property with commit short sha + run: echo "SHORT_SHA=`echo ${GITHUB_SHA} | cut -c1-7`" >> $GITHUB_ENV + + - name: Build and push hash label + uses: docker/build-push-action@v5 + with: + context: . 
+ push: true + tags: ppcelery/one-api:${{ env.SHORT_SHA }} + cache-from: type=gha + cache-to: type=gha,mode=max + + deploy: + runs-on: ubuntu-latest + needs: build_latest + steps: + - name: executing remote ssh commands using password + uses: appleboy/ssh-action@v1.0.3 + with: + host: ${{ secrets.TARGET_HOST }} + username: ${{ secrets.TARGET_HOST_USERNAME }} + password: ${{ secrets.TARGET_HOST_PASSWORD }} + port: ${{ secrets.TARGET_HOST_SSH_PORT }} + script: | + docker pull ppcelery/one-api:latest + cd /home/laisky/repo/VPS + docker-compose -f b1-docker-compose.yml up -d --remove-orphans --force-recreate oneapi + docker ps + + build_arm64_hash: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Add SHORT_SHA env property with commit short sha + run: echo "SHORT_SHA=`echo ${GITHUB_SHA} | cut -c1-7`" >> $GITHUB_ENV + + - name: Build and push arm64 hash label + uses: docker/build-push-action@v5 + with: + context: . + push: true + tags: ppcelery/one-api:arm64-${{ env.SHORT_SHA }} + platforms: linux/arm64 + cache-from: type=gha + cache-to: type=gha,mode=max + + build_arm64_latest: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push arm64 latest + uses: docker/build-push-action@v5 + with: + context: . + push: true + tags: ppcelery/one-api:arm64-latest + platforms: linux/arm64 + cache-from: type=gha + # cache-to: type=gha,mode=max diff --git a/.github/workflows/docker-image-en.yml b/.github/workflows/docker-image-en.yml deleted file mode 100644 index 30cd0e38..00000000 --- a/.github/workflows/docker-image-en.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: Publish Docker image (English) - -on: - push: - tags: - - 'v*.*.*' - workflow_dispatch: - inputs: - name: - description: 'reason' - required: false -jobs: - push_to_registries: - name: Push Docker image to multiple registries - runs-on: ubuntu-latest - permissions: - packages: write - contents: read - steps: - - name: Check out the repo - uses: actions/checkout@v3 - - - name: Check repository URL - run: | - REPO_URL=$(git config --get remote.origin.url) - if [[ $REPO_URL == *"pro" ]]; then - exit 1 - fi - - - name: Save version info - run: | - git describe --tags > VERSION - - - name: Translate - run: | - python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json - - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - - name: Log in to Docker Hub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Extract metadata (tags, labels) for Docker - id: meta - uses: docker/metadata-action@v4 - with: - images: | - justsong/one-api-en - - - name: Build and push Docker images - uses: docker/build-push-action@v3 - with: - context: . 
- platforms: linux/amd64,linux/arm64 - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml deleted file mode 100644 index 56f1d6ad..00000000 --- a/.github/workflows/docker-image.yml +++ /dev/null @@ -1,68 +0,0 @@ -name: Publish Docker image - -on: - push: - tags: - - 'v*.*.*' - workflow_dispatch: - inputs: - name: - description: 'reason' - required: false -jobs: - push_to_registries: - name: Push Docker image to multiple registries - runs-on: ubuntu-latest - permissions: - packages: write - contents: read - steps: - - name: Check out the repo - uses: actions/checkout@v3 - - - name: Check repository URL - run: | - REPO_URL=$(git config --get remote.origin.url) - if [[ $REPO_URL == *"pro" ]]; then - exit 1 - fi - - - name: Save version info - run: | - git describe --tags > VERSION - - - name: Set up QEMU - uses: docker/setup-qemu-action@v2 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 - - - name: Log in to Docker Hub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Log in to the Container registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Extract metadata (tags, labels) for Docker - id: meta - uses: docker/metadata-action@v4 - with: - images: | - justsong/one-api - ghcr.io/${{ github.repository }} - - - name: Build and push Docker images - uses: docker/build-push-action@v3 - with: - context: . - platforms: linux/amd64,linux/arm64 - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..460fe33d --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,53 @@ +name: lint + +on: + push: + branches: + - "master" + - "main" + - "test/ci" + +jobs: + unit_tests: + name: "Unit tests" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: ^1.23 + + - name: Install ffmpeg + run: sudo apt-get update && sudo apt-get install -y ffmpeg + + # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a + # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") + # in the next step as well as the next job. + - name: Test + run: go test -cover -coverprofile=coverage.txt ./... 
+ + - name: Archive code coverage results + uses: actions/upload-artifact@v4 + with: + name: code-coverage + path: coverage.txt # Make sure to use the same file name you chose for the "-coverprofile" in the "Test" step + + code_coverage: + name: "Code coverage report" + if: github.event_name == 'pull_request' # Do not run when workflow is triggered by push to main branch + runs-on: ubuntu-latest + needs: unit_tests # Depends on the artifact uploaded by the "unit_tests" job + steps: + - uses: fgrosse/go-coverage-report@v1.0.2 # Consider using a Git revision for maximum security + with: + coverage-artifact-name: "code-coverage" # can be omitted if you used this default value + coverage-file-name: "coverage.txt" # can be omitted if you used this default value + + commit_lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: wagoid/commitlint-github-action@v6 diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index 161c41e3..5e12fc9f 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -62,4 +62,4 @@ jobs: draft: true generate_release_notes: true env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 00000000..5c7617d4 --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,35 @@ +name: pr-check + +on: + pull_request: + branches: + - "master" + - "main" + +jobs: + build_latest: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: try to build + uses: docker/build-push-action@v5 + with: + context: . + push: false + tags: ppcelery/one-api:pr + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index 18641ae8..939efd8d 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -56,4 +56,4 @@ jobs: draft: true generate_release_notes: true env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 4e431e65..e7f1e0a2 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ build *.db-journal logs data +node_modules /web/node_modules cmd.md -.env \ No newline at end of file +.env +/one-api diff --git a/Dockerfile b/Dockerfile index 0f00d092..052547ed 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=$BUILDPLATFORM node:16 AS builder +FROM node:18 as builder WORKDIR /web COPY ./VERSION . @@ -16,9 +16,12 @@ WORKDIR /web/air RUN npm install RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat ../VERSION) npm run build -FROM golang:alpine AS builder2 +FROM golang:1.23.3-bullseye AS builder2 -RUN apk add --no-cache g++ +RUN apt-get update +RUN apt-get install -y --no-install-recommends g++ make gcc git build-essential ca-certificates \ + && update-ca-certificates 2>/dev/null || true \ + && rm -rf /var/lib/apt/lists/* ENV GO111MODULE=on \ CGO_ENABLED=1 \ @@ -31,14 +34,14 @@ COPY . . 
COPY --from=builder /web/build ./web/build RUN go build -trimpath -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api -FROM alpine +FROM debian:bullseye -RUN apk update \ - && apk upgrade \ - && apk add --no-cache ca-certificates tzdata \ - && update-ca-certificates 2>/dev/null || true +RUN apt-get update +RUN apt-get install -y --no-install-recommends ca-certificates haveged tzdata ffmpeg \ + && update-ca-certificates 2>/dev/null || true \ + && rm -rf /var/lib/apt/lists/* COPY --from=builder2 /build/one-api / EXPOSE 3000 WORKDIR /data -ENTRYPOINT ["/one-api"] \ No newline at end of file +ENTRYPOINT ["/one-api"] diff --git a/README.md b/README.md index 5f9947b0..638e0f25 100644 --- a/README.md +++ b/README.md @@ -1,463 +1,117 @@ -

- 中文 | English | 日本語 -

- - -

- one-api logo -

- -
- # One API -_✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 ✨_ +The original author of one-api has not been active for a long time, resulting in a backlog of PRs that cannot be updated. Therefore, I forked the code and merged some PRs that I consider important. I also welcome everyone to submit PRs, and I will respond and handle them actively and quickly. -
+Fully compatible with the upstream version; it can be used as a drop-in replacement by simply switching the container image. Docker images:
-

- - license - - - release - - - docker pull - - - release - - - GoReportCard - -

+- `ppcelery/one-api:latest` +- `ppcelery/one-api:arm64-latest` -

- 部署教程 - · - 使用方法 - · - 意见反馈 - · - 截图展示 - · - 在线演示 - · - 常见问题 - · - 相关项目 - · - 赞赏支持 -

+## Menu -> [!NOTE] -> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 -> -> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 +- [One API](#one-api) + - [Menu](#menu) + - [New Features](#new-features) + - [(Merged) Support gpt-vision](#merged-support-gpt-vision) + - [Support update user's remained quota](#support-update-users-remained-quota) + - [(Merged) Support aws claude](#merged-support-aws-claude) + - [Support openai images edits](#support-openai-images-edits) + - [Support gemini-2.0-flash-exp](#support-gemini-20-flash-exp) + - [Support replicate flux \& remix](#support-replicate-flux--remix) + - [Support replicate chat models](#support-replicate-chat-models) + - [Support OpenAI O1/O1-mini/O1-preview](#support-openai-o1o1-minio1-preview) + - [Get request's cost](#get-requests-cost) + - [Support Vertex Imagen3](#support-vertex-imagen3) + - [Support gpt-4o-audio](#support-gpt-4o-audio) + - [Bug fix](#bug-fix) + - [The token balance cannot be edited](#the-token-balance-cannot-be-edited) + - [Whisper's transcription only charges for the length of the input audio](#whispers-transcription-only-charges-for-the-length-of-the-input-audio) -> [!WARNING] -> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 +## New Features -> [!WARNING] -> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! +### (Merged) Support gpt-vision -## 功能 -1. 支持多种大模型: - + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) - + [x] [Anthropic Claude 系列模型](https://anthropic.com) (支持 AWS Claude) - + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) - + [x] [Mistral 系列模型](https://mistral.ai/) - + [x] [字节跳动豆包大模型](https://console.volcengine.com/ark/region:ark+cn-beijing/model) - + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) - + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) - + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) - + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) - + [x] [360 智脑](https://ai.360.cn) - + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) - + [x] [Moonshot AI](https://platform.moonshot.cn/) - + [x] [百川大模型](https://platform.baichuan-ai.com) - + [x] [MINIMAX](https://api.minimax.chat/) - + [x] [Groq](https://wow.groq.com/) - + [x] [Ollama](https://github.com/ollama/ollama) - + [x] [零一万物](https://platform.lingyiwanwu.com/) - + [x] [阶跃星辰](https://platform.stepfun.com/) - + [x] [Coze](https://www.coze.com/) - + [x] [Cohere](https://cohere.com/) - + [x] [DeepSeek](https://www.deepseek.com/) - + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) - + [x] [DeepL](https://www.deepl.com/) - + [x] [together.ai](https://www.together.ai/) - + [x] [novita.ai](https://www.novita.ai/) - + [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud) -2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 -3. 支持通过**负载均衡**的方式访问多个渠道。 -4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 -5. 支持**多机部署**,[详见此处](#多机部署)。 -6. 支持**令牌管理**,设置令牌的过期时间、额度、允许的 IP 范围以及允许的模型访问。 -7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。 -8. 支持**渠道管理**,批量创建渠道。 -9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。 -10. 支持渠道**设置模型列表**。 -11. 支持**查看额度明细**。 -12. 支持**用户邀请奖励**。 -13. 支持以美元为单位显示额度。 -14. 支持发布公告,设置充值链接,设置新用户初始额度。 -15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。 -16. 支持失败自动重试。 -17. 
支持绘图接口。 -18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 -19. 支持丰富的**自定义**设置, - 1. 支持自定义系统名称,logo 以及页脚。 - 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。。 -21. 支持 Cloudflare Turnstile 用户校验。 -22. 支持用户管理,支持**多种用户登录注册方式**: - + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 - + 支持使用飞书进行授权登录。 - + [GitHub 开放授权](https://github.com/settings/applications/new)。 - + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 -23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 -24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 +### Support update user's remained quota -## 部署 -### 基于 Docker 进行部署 -```shell -# 使用 SQLite 的部署命令: -docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api -# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。 -# 例如: -docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api -``` +You can update the used quota using the API key of any token, allowing other consumption to be aggregated into the one-api for centralized management. -其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 +![](https://s3.laisky.com/uploads/2024/12/oneapi-update-quota.png) -数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 +### (Merged) Support aws claude -如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 +- [feat: support aws bedrockruntime claude3 #1328](https://github.com/songquanpeng/one-api/pull/1328) +- [feat: add new claude models #1910](https://github.com/songquanpeng/one-api/pull/1910) -如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 +![](https://s3.laisky.com/uploads/2024/12/oneapi-claude.png) -如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 +### Support openai images edits -更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` +- [feat: support openai images edits api #1369](https://github.com/songquanpeng/one-api/pull/1369) -Nginx 的参考配置: -``` -server{ - server_name openai.justsong.cn; # 请根据实际情况修改你的域名 +![](https://s3.laisky.com/uploads/2024/12/oneapi-image-edit.png) - location / { - client_max_body_size 64m; - proxy_http_version 1.1; - proxy_pass http://localhost:3000; # 请根据实际情况修改你的端口 - proxy_set_header Host $host; - proxy_set_header X-Forwarded-For $remote_addr; - proxy_cache_bypass $http_upgrade; - proxy_set_header Accept-Encoding gzip; - proxy_read_timeout 300s; # GPT-4 需要较长的超时时间,请自行调整 - } +### Support gemini-2.0-flash-exp + +- [feat: add gemini-2.0-flash-exp #1983](https://github.com/songquanpeng/one-api/pull/1983) + +![](https://s3.laisky.com/uploads/2024/12/oneapi-gemini-flash.png) + +### Support replicate flux & remix + +- [feature: 支持 replicate 的绘图 #1954](https://github.com/songquanpeng/one-api/pull/1954) +- [feat: image edits/inpaiting 支持 replicate 的 flux remix #1986](https://github.com/songquanpeng/one-api/pull/1986) + +![](https://s3.laisky.com/uploads/2024/12/oneapi-replicate-1.png) + +![](https://s3.laisky.com/uploads/2024/12/oneapi-replicate-2.png) + 
+![](https://s3.laisky.com/uploads/2024/12/oneapi-replicate-3.png) + +### Support replicate chat models + +- [feat: 支持 replicate chat models #1989](https://github.com/songquanpeng/one-api/pull/1989) + +### Support OpenAI O1/O1-mini/O1-preview + +- [feat: add openai o1 #1990](https://github.com/songquanpeng/one-api/pull/1990) + +### Get request's cost + +Each chat completion request will include a `X-Oneapi-Request-Id` in the returned headers. You can use this request id to request `GET /api/cost/request/:request_id` to get the cost of this request. + +The returned structure is: + +```go +type UserRequestCost struct { + Id int `json:"id"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UserID int `json:"user_id"` + RequestID string `json:"request_id"` + Quota int64 `json:"quota"` + CostUSD float64 `json:"cost_usd" gorm:"-"` } ``` -之后使用 Let's Encrypt 的 certbot 配置 HTTPS: -```bash -# Ubuntu 安装 certbot: -sudo snap install --classic certbot -sudo ln -s /snap/bin/certbot /usr/bin/certbot -# 生成证书 & 修改 Nginx 配置 -sudo certbot --nginx -# 根据指示进行操作 -# 重启 Nginx -sudo service nginx restart -``` +### Support Vertex Imagen3 -初始账号用户名为 `root`,密码为 `123456`。 +- [feat: support vertex imagen3 #2030](https://github.com/songquanpeng/one-api/pull/2030) +![](https://s3.laisky.com/uploads/2025/01/oneapi-imagen3.png) -### 基于 Docker Compose 进行部署 +### Support gpt-4o-audio -> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 +- [feat: support gpt-4o-audio #2032](https://github.com/songquanpeng/one-api/pull/2032) -```shell -# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 -docker-compose up -d +![](https://s3.laisky.com/uploads/2025/01/oneapi-audio-1.png) -# 查看部署状态 -docker-compose ps -``` +![](https://s3.laisky.com/uploads/2025/01/oneapi-audio-2.png) -### 手动部署 -1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: - ```shell - git clone https://github.com/songquanpeng/one-api.git +## Bug fix - # 构建前端 - cd one-api/web/default - npm install - npm run build +### The token balance cannot be edited - # 构建后端 - cd ../.. - go mod download - go build -ldflags "-s -w" -o one-api - ```` -2. 运行: - ```shell - chmod u+x one-api - ./one-api --port 3000 --log-dir ./logs - ``` -3. 访问 [http://localhost:3000/](http://localhost:3000/) 并登录。初始账号用户名为 `root`,密码为 `123456`。 +- [BUGFIX: 更新令牌时的一些问题 #1933](https://github.com/songquanpeng/one-api/pull/1933) -更加详细的部署教程[参见此处](https://iamazing.cn/page/how-to-deploy-a-website)。 +### Whisper's transcription only charges for the length of the input audio -### 多机部署 -1. 所有服务器 `SESSION_SECRET` 设置一样的值。 -2. 必须设置 `SQL_DSN`,使用 MySQL 数据库而非 SQLite,所有服务器连接同一个数据库。 -3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 -4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis,无论主从。 -5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 -6. 从服务器上**分别**装好 Redis,设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。 -7. 
如果主服务器访问数据库延迟也比较高,则也需要启用 Redis,并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 - -环境变量的具体使用方法详见[此处](#环境变量)。 - -### 宝塔部署教程 - -详见 [#175](https://github.com/songquanpeng/one-api/issues/175)。 - -如果部署后访问出现空白页面,详见 [#97](https://github.com/songquanpeng/one-api/issues/97)。 - -### 部署第三方服务配合 One API 使用 -> 欢迎 PR 添加更多示例。 - -#### ChatGPT Next Web -项目主页:https://github.com/Yidadaa/ChatGPT-Next-Web - -```bash -docker run --name chat-next-web -d -p 3001:3000 yidadaa/chatgpt-next-web -``` - -注意修改端口号,之后在页面上设置接口地址(例如:https://openai.justsong.cn/ )和 API Key 即可。 - -#### ChatGPT Web -项目主页:https://github.com/Chanzhaoyu/chatgpt-web - -```bash -docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://openai.justsong.cn -e OPENAI_API_KEY=sk-xxx chenzhaoyu94/chatgpt-web -``` - -注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 - -#### QChatGPT - QQ机器人 -项目主页:https://github.com/RockChinQ/QChatGPT - -根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。 - -运行期间可以通过`!model`命令查看、切换可用模型。 - -### 部署到第三方平台 -
-部署到 Sealos -
- -> Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩。 - -点击以下按钮一键部署(部署后访问出现 404 请等待 3~5 分钟): - -[![Deploy-on-Sealos.svg](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api) - -
-
- -
-部署到 Zeabur -
- -> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 - -[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) - -1. 首先 fork 一份代码。 -2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 -3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 -4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。 -5. 然后在 Service -> Add Service,选择 Git(第一次使用需要先授权),选择你 fork 的仓库。 -6. Deploy 会自动开始,先取消。进入下方 Variable,添加一个 `PORT`,值为 `3000`,再添加一个 `SQL_DSN`,值为 `:@tcp(:)/one-api` ,然后保存。 注意如果不填写 `SQL_DSN`,数据将无法持久化,重新部署后数据会丢失。 -7. 选择 Redeploy。 -8. 进入下方 Domains,选择一个合适的域名前缀,如 "my-one-api",最终域名为 "my-one-api.zeabur.app",也可以 CNAME 自己的域名。 -9. 等待部署完成,点击生成的域名进入 One API。 - -
-
- -
-部署到 Render -
- -> Render 提供免费额度,绑卡后可以进一步提升额度 - -Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com - -
-
- -## 配置 -系统本身开箱即用。 - -你可以通过设置环境变量或者命令行参数进行配置。 - -等到系统启动后,使用 `root` 用户登录系统并做进一步的配置。 - -**Note**:如果你不知道某个配置项的含义,可以临时删掉值以看到进一步的提示文字。 - -## 使用方法 -在`渠道`页面中添加你的 API Key,之后在`令牌`页面中新增访问令牌。 - -之后就可以使用你的令牌访问 One API 了,使用方式与 [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) 一致。 - -你需要在各种用到 OpenAI API 的地方设置 API Base 为你的 One API 的部署地址,例如:`https://openai.justsong.cn`,API Key 则为你在 One API 中生成的令牌。 - -注意,具体的 API Base 的格式取决于你所使用的客户端。 - -例如对于 OpenAI 的官方库: -```bash -OPENAI_API_KEY="sk-xxxxxx" -OPENAI_API_BASE="https://:/v1" -``` - -```mermaid -graph LR - A(用户) - A --->|使用 One API 分发的 key 进行请求| B(One API) - B -->|中继请求| C(OpenAI) - B -->|中继请求| D(Azure) - B -->|中继请求| E(其他 OpenAI API 格式下游渠道) - B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道) -``` - -可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 -注意,需要是管理员用户创建的令牌才能指定渠道 ID。 - -不加的话将会使用负载均衡的方式使用多个渠道。 - -### 环境变量 -> One API 支持从 `.env` 文件中读取环境变量,请参照 `.env.example` 文件,使用时请将其重命名为 `.env`。 -1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 - + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` - + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。 -2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 - + 例子:`SESSION_SECRET=random_string` -3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。 - + 例子: - + MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi` - + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈) - + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 - + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 - + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 - + 请根据你的数据库配置修改下列参数(或者保持默认值): - + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。 - + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 - + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 - + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 -4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 -5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 - + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` -6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 - + 例子:`MEMORY_CACHE_ENABLED=true` -7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 - + 例子:`SYNC_FREQUENCY=60` -8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 - + 例子:`NODE_TYPE=slave` -9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 - + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` -10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 - +例子:`CHANNEL_TEST_FREQUENCY=1440` -11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 - + 例子:`POLLING_INTERVAL=5` -12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 - + 例子:`BATCH_UPDATE_ENABLED=true` - + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 -13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 - + 例子:`BATCH_UPDATE_INTERVAL=5` -14. 请求频率限制: - + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 - + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 -15. 编码器缓存设置: - + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 - + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 -16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 -17. `RELAY_PROXY`:设置后使用该代理来请求 API。 -18. `USER_CONTENT_REQUEST_TIMEOUT`:用户上传内容下载超时时间,单位为秒。 -19. `USER_CONTENT_REQUEST_PROXY`:设置后使用该代理来请求用户上传的内容,例如图片。 -20. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 -21. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 -22. 
`GEMINI_VERSION`:One API 所使用的 Gemini 版本,默认为 `v1`。 -23. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 -24. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 -25. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 -26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 -27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 -28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 - -### 命令行参数 -1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 - + 例子:`--port 3000` -2. `--log-dir `: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 - + 例子:`--log-dir ./logs` -3. `--version`: 打印系统版本号并退出。 -4. `--help`: 查看命令的使用帮助和参数说明。 - -## 演示 -### 在线演示 -注意,该演示站不提供对外服务: -https://openai.justsong.cn - -### 截图展示 -![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png) -![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png) - -## 常见问题 -1. 额度是什么?怎么计算的?One API 的额度计算有问题? - + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) - + 其中补全倍率对于 GPT3.5 固定为 1.33,GPT4 为 2,与官方保持一致。 - + 如果是非流模式,官方接口会返回消耗的总 token,但是你要注意提示和补全的消耗倍率不一样。 - + 注意,One API 的默认倍率就是官方倍率,是已经调整过的。 -2. 账户额度足够为什么提示额度不足? - + 请检查你的令牌额度是否足够,这个和账户额度是分开的。 - + 令牌额度仅供用户设置最大使用量,用户可自由设置。 -3. 提示无可用渠道? - + 请检查的用户分组和渠道分组设置。 - + 以及渠道的模型设置。 -4. 渠道测试报错:`invalid character '<' looking for beginning of value` - + 这是因为返回值不是合法的 JSON,而是一个 HTML 页面。 - + 大概率是你的部署站的 IP 或代理的节点被 CloudFlare 封禁了。 -5. ChatGPT Next Web 报错:`Failed to fetch` - + 部署的时候不要设置 `BASE_URL`。 - + 检查你的接口地址和 API Key 有没有填对。 - + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 -6. 报错:`当前分组负载已饱和,请稍后再试` - + 上游渠道 429 了。 -7. 升级之后我的数据会丢失吗? - + 如果使用 MySQL,不会。 - + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 -8. 升级之前数据库需要做变更吗? - + 一般情况下不需要,系统将在初始化的时候自动调整。 - + 如果需要的话,我会在更新日志中说明,并给出脚本。 -9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`? 
- + 这是检测到 ability 表里有些记录的渠道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的渠道。 - + 对于每一个渠道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该渠道支持该模型。 - -## 相关项目 -* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 -* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用 -* [VChart](https://github.com/VisActor/VChart): 不只是开箱即用的多端图表库,更是生动灵活的数据故事讲述者。 -* [VMind](https://github.com/VisActor/VMind): 不仅自动,还很智能。开源智能可视化解决方案。 - -## 注意 - -本项目使用 MIT 协议进行开源,**在此基础上**,必须在页面底部保留署名以及指向本项目的链接。如果不想保留署名,必须首先获得授权。 - -同样适用于基于本项目的二开项目。 - -依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 +- [feat(audio): count whisper-1 quota by audio duration #2022](https://github.com/songquanpeng/one-api/pull/2022) diff --git a/common/client/init.go b/common/client/init.go index f803cbf8..0bd29ef7 100644 --- a/common/client/init.go +++ b/common/client/init.go @@ -2,11 +2,12 @@ package client import ( "fmt" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" "net/http" "net/url" "time" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" ) var HTTPClient *http.Client diff --git a/common/config/config.go b/common/config/config.go index 11da0b96..a9e56b07 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -1,16 +1,30 @@ package config import ( - "github.com/songquanpeng/one-api/common/env" + "crypto/rand" + "encoding/base64" + "fmt" "os" "strconv" "strings" "sync" "time" - "github.com/google/uuid" + "github.com/songquanpeng/one-api/common/env" ) +func init() { + if SessionSecret == "" { + fmt.Println("SESSION_SECRET not set, using random secret") + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + panic(fmt.Sprintf("failed to generate random secret: %v", err)) + } + + SessionSecret = base64.StdEncoding.EncodeToString(key) + } +} + var SystemName = "One API" var ServerAddress = "http://localhost:3000" var Footer = "" @@ -23,7 +37,7 @@ var DisplayTokenStatEnabled = true // Any options with "Secret", "Token" in its key won't be return by GetOptions -var SessionSecret = uuid.New().String() +var SessionSecret = os.Getenv("SESSION_SECRET") var OptionMap map[string]string var OptionMapRWMutex sync.RWMutex @@ -35,6 +49,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var OidcEnabled = false var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -70,6 +85,13 @@ var GitHubClientSecret = "" var LarkClientId = "" var LarkClientSecret = "" +var OidcClientId = "" +var OidcClientSecret = "" +var OidcWellKnown = "" +var OidcAuthorizationEndpoint = "" +var OidcTokenEndpoint = "" +var OidcUserinfoEndpoint = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" @@ -104,6 +126,7 @@ var BatchUpdateEnabled = false var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second +var IdleTimeout = env.Int("IDLE_TIMEOUT", 30) // unit is second var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") @@ -152,3 +175,5 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false) var RelayProxy = env.String("RELAY_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) + +var EnforceIncludeUsage = 
env.Bool("ENFORCE_INCLUDE_USAGE", false) diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 90556b3a..02868e59 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -3,9 +3,12 @@ package ctxkey const ( Config = "config" Id = "id" + RequestId = "X-Oneapi-Request-Id" Username = "username" Role = "role" Status = "status" + ChannelModel = "channel_model" + ChannelRatio = "channel_ratio" Channel = "channel" ChannelId = "channel_id" SpecificChannelId = "specific_channel_id" @@ -15,9 +18,12 @@ const ( Group = "group" ModelMapping = "model_mapping" ChannelName = "channel_name" + ContentType = "content_type" TokenId = "token_id" TokenName = "token_name" BaseURL = "base_url" AvailableModels = "available_models" KeyRequestBody = "key_request_body" + SystemPrompt = "system_prompt" + Meta = "meta" ) diff --git a/common/gin.go b/common/gin.go index 549d3279..c2451c88 100644 --- a/common/gin.go +++ b/common/gin.go @@ -3,41 +3,51 @@ package common import ( "bytes" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/ctxkey" "io" + "reflect" "strings" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" ) -func GetRequestBody(c *gin.Context) ([]byte, error) { - requestBody, _ := c.Get(ctxkey.KeyRequestBody) - if requestBody != nil { - return requestBody.([]byte), nil +func GetRequestBody(c *gin.Context) (requestBody []byte, err error) { + if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil { + return requestBodyCache.([]byte), nil } - requestBody, err := io.ReadAll(c.Request.Body) + requestBody, err = io.ReadAll(c.Request.Body) if err != nil { - return nil, err + return nil, errors.Wrap(err, "read request body failed") } _ = c.Request.Body.Close() c.Set(ctxkey.KeyRequestBody, requestBody) - return requestBody.([]byte), nil + + return requestBody, nil } func UnmarshalBodyReusable(c *gin.Context, v any) error { requestBody, err := GetRequestBody(c) if err != nil { - return err + return errors.Wrap(err, "get request body failed") } + + // check v should be a pointer + if v == nil || reflect.TypeOf(v).Kind() != reflect.Ptr { + return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v)) + } + contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) + err = json.Unmarshal(requestBody, v) } else { - // skip for now - // TODO: someday non json request have variant model, we will need to implementation this + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + err = c.ShouldBind(v) } if err != nil { - return err + return errors.Wrap(err, "unmarshal request body failed") } + // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil diff --git a/common/helper/audio.go b/common/helper/audio.go new file mode 100644 index 00000000..146d9412 --- /dev/null +++ b/common/helper/audio.go @@ -0,0 +1,63 @@ +package helper + +import ( + "bytes" + "context" + "io" + "math" + "os" + "os/exec" + "strconv" + + "github.com/pkg/errors" +) + +// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string. 
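+// The file is created in the system temp directory; the caller is responsible for removing it when it is no longer needed.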
+func SaveTmpFile(filename string, data io.Reader) (string, error) { + if data == nil { + return "", errors.New("data is nil") + } + + f, err := os.CreateTemp("", "*-"+filename) + if err != nil { + return "", errors.Wrapf(err, "failed to create temporary file %s", filename) + } + defer f.Close() + + _, err = io.Copy(f, data) + if err != nil { + return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename) + } + + return f.Name(), nil +} + +// GetAudioTokens returns the number of tokens in an audio file. +func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) { + filename, err := SaveTmpFile("audio", audio) + if err != nil { + return 0, errors.Wrap(err, "failed to save audio to temporary file") + } + defer os.Remove(filename) + + duration, err := GetAudioDuration(ctx, filename) + if err != nil { + return 0, errors.Wrap(err, "failed to get audio tokens") + } + + return int(math.Ceil(duration)) * tokensPerSecond, nil +} + +// GetAudioDuration returns the duration of an audio file in seconds. +func GetAudioDuration(ctx context.Context, filename string) (float64, error) { + // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} + c := exec.CommandContext(ctx, "/usr/bin/ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration") + } + + // Actually gpt-4-audio calculates tokens with 0.1s precision, + // while whisper calculates tokens with 1s precision + return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) +} diff --git a/common/helper/audio_test.go b/common/helper/audio_test.go new file mode 100644 index 00000000..15f55bbb --- /dev/null +++ b/common/helper/audio_test.go @@ -0,0 +1,55 @@ +package helper + +import ( + "context" + "io" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetAudioDuration(t *testing.T) { + t.Run("should return correct duration for a valid audio file", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "test_audio*.mp3") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // download test audio file + resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a") + require.NoError(t, err) + defer resp.Body.Close() + + _, err = io.Copy(tmpFile, resp.Body) + require.NoError(t, err) + require.NoError(t, tmpFile.Close()) + + duration, err := GetAudioDuration(context.Background(), tmpFile.Name()) + require.NoError(t, err) + require.Equal(t, duration, 3.904) + }) + + t.Run("should return an error for a non-existent file", func(t *testing.T) { + _, err := GetAudioDuration(context.Background(), "non_existent_file.mp3") + require.Error(t, err) + }) +} + +func TestGetAudioTokens(t *testing.T) { + t.Run("should return correct tokens for a valid audio file", func(t *testing.T) { + // download test audio file + resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a") + require.NoError(t, err) + defer resp.Body.Close() + + tokens, err := GetAudioTokens(context.Background(), resp.Body, 50) + require.NoError(t, err) + require.Equal(t, tokens, 200) + }) + + t.Run("should return an error for a non-existent file", func(t *testing.T) { + _, err := GetAudioTokens(context.Background(), nil, 1) + require.Error(t, err) + }) +} diff --git a/common/helper/helper.go b/common/helper/helper.go index e06dfb6e..662de16c 100644 --- 
a/common/helper/helper.go +++ b/common/helper/helper.go @@ -2,8 +2,6 @@ package helper import ( "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/random" "html/template" "log" "net" @@ -11,6 +9,9 @@ import ( "runtime" "strconv" "strings" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/random" ) func OpenBrowser(url string) { @@ -137,3 +138,23 @@ func String2Int(str string) int { } return num } + +func Float64PtrMax(p *float64, maxValue float64) *float64 { + if p == nil { + return nil + } + if *p > maxValue { + return &maxValue + } + return p +} + +func Float64PtrMin(p *float64, minValue float64) *float64 { + if p == nil { + return nil + } + if *p < minValue { + return &minValue + } + return p +} diff --git a/common/helper/time.go b/common/helper/time.go index 302746db..b74cc74a 100644 --- a/common/helper/time.go +++ b/common/helper/time.go @@ -5,6 +5,7 @@ import ( "time" ) +// GetTimestamp get current timestamp in seconds func GetTimestamp() int64 { return time.Now().Unix() } diff --git a/common/init.go b/common/init.go index 6fd84764..391bc088 100644 --- a/common/init.go +++ b/common/init.go @@ -3,11 +3,12 @@ package common import ( "flag" "fmt" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" "log" "os" "path/filepath" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" ) var ( @@ -27,15 +28,15 @@ func printHelp() { func Init() { flag.Parse() - if *PrintVersion { - fmt.Println(Version) - os.Exit(0) - } + // if *PrintVersion { + // fmt.Println(Version) + // os.Exit(0) + // } - if *PrintHelp { - printHelp() - os.Exit(0) - } + // if *PrintHelp { + // printHelp() + // os.Exit(0) + // } if os.Getenv("SESSION_SECRET") != "" { if os.Getenv("SESSION_SECRET") == "random_string" { diff --git a/common/message/email.go b/common/message/email.go index 187ac8c3..6ba27019 100644 --- a/common/message/email.go +++ b/common/message/email.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/base64" "fmt" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" "net" "net/smtp" @@ -18,7 +19,7 @@ func shouldAuth() bool { func SendEmail(subject string, receiver string, content string) error { if receiver == "" { - return fmt.Errorf("receiver is empty") + return errors.Errorf("receiver is empty") } if config.SMTPFrom == "" { // for compatibility config.SMTPFrom = config.SMTPAccount @@ -57,7 +58,7 @@ func SendEmail(subject string, receiver string, content string) error { var err error if config.SMTPPort == 465 { tlsConfig := &tls.Config{ - InsecureSkipVerify: true, + InsecureSkipVerify: false, ServerName: config.SMTPServer, } conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) diff --git a/common/message/main.go b/common/message/main.go index 5ce82a64..a7d3269d 100644 --- a/common/message/main.go +++ b/common/message/main.go @@ -1,7 +1,7 @@ package message import ( - "fmt" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" ) @@ -18,5 +18,5 @@ func Notify(by string, title string, description string, content string) error { if by == ByMessagePusher { return SendMessage(title, description, content) } - return fmt.Errorf("unknown notify method: %s", by) + return errors.Errorf("unknown notify method: %s", by) } diff --git a/common/message/message-pusher.go b/common/message/message-pusher.go index 69949b4b..bc160891 100644 --- a/common/message/message-pusher.go +++ 
b/common/message/message-pusher.go @@ -3,7 +3,7 @@ package message import ( "bytes" "encoding/json" - "errors" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" "net/http" ) diff --git a/common/random/main.go b/common/random/main.go index dbb772cd..c3c69488 100644 --- a/common/random/main.go +++ b/common/random/main.go @@ -1,10 +1,11 @@ package random import ( - "github.com/google/uuid" "math/rand" "strings" "time" + + "github.com/google/uuid" ) func GetUUID() string { diff --git a/common/redis.go b/common/redis.go index f3205567..55d4931c 100644 --- a/common/redis.go +++ b/common/redis.go @@ -2,13 +2,15 @@ package common import ( "context" + "os" + "strings" + "time" + "github.com/go-redis/redis/v8" "github.com/songquanpeng/one-api/common/logger" - "os" - "time" ) -var RDB *redis.Client +var RDB redis.Cmdable var RedisEnabled = true // InitRedisClient This function is called after init() @@ -23,13 +25,23 @@ func InitRedisClient() (err error) { logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") return nil } - logger.SysLog("Redis is enabled") - opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) - if err != nil { - logger.FatalLog("failed to parse Redis connection string: " + err.Error()) + redisConnString := os.Getenv("REDIS_CONN_STRING") + if os.Getenv("REDIS_MASTER_NAME") == "" { + logger.SysLog("Redis is enabled") + opt, err := redis.ParseURL(redisConnString) + if err != nil { + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) + } + RDB = redis.NewClient(opt) + } else { + // cluster mode + logger.SysLog("Redis cluster mode enabled") + RDB = redis.NewUniversalClient(&redis.UniversalOptions{ + Addrs: strings.Split(redisConnString, ","), + Password: os.Getenv("REDIS_PASSWORD"), + MasterName: os.Getenv("REDIS_MASTER_NAME"), + }) } - RDB = redis.NewClient(opt) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/common/render/render.go b/common/render/render.go index 646b3777..e565c0b7 100644 --- a/common/render/render.go +++ b/common/render/render.go @@ -3,9 +3,10 @@ package render import ( "encoding/json" "fmt" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" - "strings" ) func StringData(c *gin.Context, str string) { diff --git a/controller/auth/github.go b/controller/auth/github.go index 15542655..5d8029b7 100644 --- a/controller/auth/github.go +++ b/controller/auth/github.go @@ -3,10 +3,10 @@ package auth import ( "bytes" "encoding/json" - "errors" "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" diff --git a/controller/auth/lark.go b/controller/auth/lark.go index eb06dde9..d4ff8a57 100644 --- a/controller/auth/lark.go +++ b/controller/auth/lark.go @@ -3,17 +3,18 @@ package auth import ( "bytes" "encoding/json" - "errors" "fmt" + "net/http" + "strconv" + "time" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" - "time" ) type LarkOAuthResponse struct { @@ -40,7 +41,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { if err != nil { return nil, err } - req, err := http.NewRequest("POST", 
"https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) if err != nil { return nil, err } diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go new file mode 100644 index 00000000..7b4ad4b9 --- /dev/null +++ b/controller/auth/oidc.go @@ -0,0 +1,225 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type OidcResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type OidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func getOidcUserInfoByCode(code string) (*OidcUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.OidcClientId, + "client_secret": config.OidcClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + defer res.Body.Close() + var oidcResponse OidcResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + var oidcUser OidcUser + err = json.NewDecoder(res2.Body).Decode(&oidcUser) + if err != nil { + return nil, err + } + return &oidcUser, nil +} + +func OidcAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + OidcBind(c) + return + } + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + err := 
user.FillUserByOidcId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Email = oidcUser.Email + if oidcUser.PreferredUsername != "" { + user.Username = oidcUser.PreferredUsername + } else { + user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if oidcUser.Name != "" { + user.DisplayName = oidcUser.Name + } else { + user.DisplayName = "OIDC User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func OidcBind(c *gin.Context) { + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 OIDC 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.OidcId = oidcUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/auth/wechat.go b/controller/auth/wechat.go index a561aec0..12aec05b 100644 --- a/controller/auth/wechat.go +++ b/controller/auth/wechat.go @@ -2,16 +2,17 @@ package auth import ( "encoding/json" - "errors" "fmt" + "net/http" + "strconv" + "time" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" - "time" ) type wechatLoginResponse struct { diff --git a/controller/billing.go b/controller/billing.go index 0d03e4c1..e837157f 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -17,9 +17,11 @@ func GetSubscription(c *gin.Context) { if config.DisplayTokenStatEnabled { tokenId := c.GetInt(ctxkey.TokenId) token, err = model.GetTokenById(tokenId) - expiredTime = token.ExpiredTime - remainQuota = token.RemainQuota - usedQuota = token.UsedQuota + if err == nil { + expiredTime = token.ExpiredTime + remainQuota = token.RemainQuota + usedQuota = token.UsedQuota + } } else { userId := c.GetInt(ctxkey.Id) remainQuota, err = model.GetUserQuota(userId) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 53592744..c7d2d240 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -2,20 +2,20 @@ package controller import ( "encoding/json" - "errors" "fmt" - "github.com/songquanpeng/one-api/common/client" - "github.com/songquanpeng/one-api/common/config" - 
"github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/channeltype" "io" "net/http" "strconv" "time" "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" + "github.com/songquanpeng/one-api/relay/channeltype" ) // https://github.com/songquanpeng/one-api/issues/79 @@ -81,6 +81,36 @@ type APGC2DGPTUsageResponse struct { TotalUsed float64 `json:"total_used"` } +type SiliconFlowUsageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Status bool `json:"status"` + Data struct { + ID string `json:"id"` + Name string `json:"name"` + Image string `json:"image"` + Email string `json:"email"` + IsAdmin bool `json:"isAdmin"` + Balance string `json:"balance"` + Status string `json:"status"` + Introduction string `json:"introduction"` + Role string `json:"role"` + ChargeBalance string `json:"chargeBalance"` + TotalBalance string `json:"totalBalance"` + Category string `json:"category"` + } `json:"data"` +} + +type DeepSeekUsageResponse struct { + IsAvailable bool `json:"is_available"` + BalanceInfos []struct { + Currency string `json:"currency"` + TotalBalance string `json:"total_balance"` + GrantedBalance string `json:"granted_balance"` + ToppedUpBalance string `json:"topped_up_balance"` + } `json:"balance_infos"` +} + // GetAuthHeader get auth header func GetAuthHeader(token string) http.Header { h := http.Header{} @@ -101,7 +131,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He return nil, err } if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("status code: %d", res.StatusCode) + return nil, errors.Errorf("status code: %d", res.StatusCode) } body, err := io.ReadAll(res.Body) if err != nil { @@ -166,7 +196,7 @@ func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { return 0, err } if !response.Success { - return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) + return 0, errors.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) } channel.UpdateBalance(response.Data.TotalPoints) return response.Data.TotalPoints, nil @@ -203,6 +233,57 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { return response.TotalAvailable, nil } +func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) { + url := "https://api.siliconflow.cn/v1/user/info" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := SiliconFlowUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if response.Code != 20000 { + return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message) + } + balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + +func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) { + url := "https://api.deepseek.com/user/balance" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := DeepSeekUsageResponse{} + err = json.Unmarshal(body, &response) + if err 
!= nil { + return 0, err + } + index := -1 + for i, balanceInfo := range response.BalanceInfos { + if balanceInfo.Currency == "CNY" { + index = i + break + } + } + if index == -1 { + return 0, errors.New("currency CNY not found") + } + balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := channeltype.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { @@ -227,6 +308,10 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return updateChannelAPI2GPTBalance(channel) case channeltype.AIGC2D: return updateChannelAIGC2DBalance(channel) + case channeltype.SiliconFlow: + return updateChannelSiliconFlowBalance(channel) + case channeltype.DeepSeek: + return updateChannelDeepSeekBalance(channel) default: return 0, errors.New("尚未实现") } diff --git a/controller/channel-test.go b/controller/channel-test.go index f8327284..a49dacbc 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -3,18 +3,9 @@ package controller import ( "bytes" "encoding/json" - "errors" "fmt" - "io" - "net/http" - "net/http/httptest" - "net/url" - "strconv" - "strings" - "sync" - "time" - "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" @@ -28,6 +19,14 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "sync" + "time" ) func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { @@ -66,7 +65,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques apiType := channeltype.ToAPIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { - return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + return errors.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } adaptor.Init(meta) modelName := request.Model @@ -76,9 +75,9 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques if len(modelNames) > 0 { modelName = modelNames[0] } - if modelMap != nil && modelMap[modelName] != "" { - modelName = modelMap[modelName] - } + } + if modelMap != nil && modelMap[modelName] != "" { + modelName = modelMap[modelName] } meta.OriginModelName, meta.ActualModelName = request.Model, modelName request.Model = modelName @@ -103,7 +102,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques } usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error + return errors.Errorf("%s", respErr.Error.Message), &respErr.Error } if usage == nil { return errors.New("usage is nil"), nil diff --git a/controller/log.go b/controller/log.go index 665f49be..00d062e5 100644 --- a/controller/log.go +++ b/controller/log.go @@ -1,12 +1,13 @@ package controller import ( + "net/http" + "strconv" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" ) func GetAllLogs(c *gin.Context) { diff --git a/controller/misc.go b/controller/misc.go index 
2928b8fb..ae900870 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": config.EmailVerificationEnabled, - "github_oauth": config.GitHubOAuthEnabled, - "github_client_id": config.GitHubClientId, - "lark_client_id": config.LarkClientId, - "system_name": config.SystemName, - "logo": config.Logo, - "footer_html": config.Footer, - "wechat_qrcode": config.WeChatAccountQRCodeImageURL, - "wechat_login": config.WeChatAuthEnabled, - "server_address": config.ServerAddress, - "turnstile_check": config.TurnstileCheckEnabled, - "turnstile_site_key": config.TurnstileSiteKey, - "top_up_link": config.TopUpLink, - "chat_link": config.ChatLink, - "quota_per_unit": config.QuotaPerUnit, - "display_in_currency": config.DisplayInCurrencyEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, + "oidc": config.OidcEnabled, + "oidc_client_id": config.OidcClientId, + "oidc_well_known": config.OidcWellKnown, + "oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, + "oidc_token_endpoint": config.OidcTokenEndpoint, + "oidc_userinfo_endpoint": config.OidcUserinfoEndpoint, }, }) return diff --git a/controller/model.go b/controller/model.go index dcbe709e..e4822f89 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,6 +2,9 @@ package controller import ( "fmt" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" @@ -11,8 +14,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "net/http" - "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -68,6 +69,10 @@ func init() { continue } adaptor := relay.GetAdaptor(i) + if adaptor == nil { + continue + } + channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { @@ -106,6 +111,10 @@ func init() { channelId2Models = make(map[int][]string) for i := 1; i < channeltype.Dummy; i++ { adaptor := relay.GetAdaptor(channeltype.ToAPIType(i)) + if adaptor == nil { + continue + } + meta := &meta.Meta{ ChannelType: i, } diff --git a/controller/redemption.go b/controller/redemption.go index 1d0ffbad..88227241 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -1,14 +1,15 @@ package controller import ( + "net/http" + "strconv" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" - 
"net/http" - "strconv" ) func GetAllRedemptions(c *gin.Context) { diff --git a/controller/relay.go b/controller/relay.go index 49358e25..94c75104 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" @@ -26,7 +27,8 @@ import ( func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case relaymode.ImagesGenerations: + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: err = controller.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: fallthrough @@ -45,10 +47,6 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := relaymode.GetByPath(c.Request.URL.Path) - if config.DebugEnabled { - requestBody, _ := common.GetRequestBody(c) - logger.Debugf(ctx, "request body: %s", string(requestBody)) - } channelId := c.GetInt(ctxkey.ChannelId) userId := c.GetInt(ctxkey.Id) bizErr := relayHelper(c, relayMode) @@ -60,11 +58,11 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes - if !shouldRetry(c, bizErr.StatusCode) { - logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) + if err := shouldRetry(c, bizErr.StatusCode); err != nil { + logger.Errorf(ctx, "relay error happen, won't retry since of %v", err.Error()) retryTimes = 0 } for i := retryTimes; i > 0; i-- { @@ -87,9 +85,9 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - // BUG: bizErr is in race condition - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) } + if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" @@ -103,26 +101,23 @@ func Relay(c *gin.Context) { } } -func shouldRetry(c *gin.Context, statusCode int) bool { - if _, ok := c.Get(ctxkey.SpecificChannelId); ok { - return false +// shouldRetry returns nil if should retry, otherwise returns error +func shouldRetry(c *gin.Context, statusCode int) error { + if v, ok := c.Get(ctxkey.SpecificChannelId); ok { + return errors.Errorf("specific channel = %v", v) } + if statusCode == http.StatusTooManyRequests { - return true + return errors.Errorf("status code = %d", statusCode) } if statusCode/100 == 5 { - return true + return errors.Errorf("status code = %d", statusCode) } - if statusCode == http.StatusBadRequest { - return false - } - if statusCode/100 == 2 { - return false - } - return true + + return nil } -func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { +func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", 
channelId, userId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { diff --git a/controller/token.go b/controller/token.go index 668ccd97..7c26dbf8 100644 --- a/controller/token.go +++ b/controller/token.go @@ -2,17 +2,42 @@ package controller import ( "fmt" + "net/http" + "strconv" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" - "net/http" - "strconv" ) +func GetRequestCost(c *gin.Context) { + reqId := c.Param("request_id") + if reqId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "request_id 不能为空", + }) + return + } + + docu, err := model.GetCostByRequestId(reqId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, docu) +} + func GetAllTokens(c *gin.Context) { userId := c.GetInt(ctxkey.Id) p, _ := strconv.Atoi(c.Query("p")) @@ -107,22 +132,24 @@ func GetTokenStatus(c *gin.Context) { }) } -func validateToken(c *gin.Context, token model.Token) error { +func validateToken(c *gin.Context, token *model.Token) error { if len(token.Name) > 30 { return fmt.Errorf("令牌名称过长") } + if token.Subnet != nil && *token.Subnet != "" { err := network.IsValidSubnets(*token.Subnet) if err != nil { return fmt.Errorf("无效的网段：%s", err.Error()) } } + return nil } func AddToken(c *gin.Context) { - token := model.Token{} - err := c.ShouldBindJSON(&token) + token := new(model.Token) + err := c.ShouldBindJSON(token) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -130,6 +157,7 @@ func AddToken(c *gin.Context) { }) return } + err = validateToken(c, token) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -185,11 +213,18 @@ func DeleteToken(c *gin.Context) { return } -func UpdateToken(c *gin.Context) { - userId := c.GetInt(ctxkey.Id) - statusOnly := c.Query("status_only") - token := model.Token{} - err := c.ShouldBindJSON(&token) +type consumeTokenRequest struct { + // AddUsedQuota adds or subtracts used quota from another source + AddUsedQuota uint `json:"add_used_quota" gorm:"-"` + // AddReason is the reason for adding or subtracting used quota + AddReason string `json:"add_reason" gorm:"-"` +} + +// ConsumeToken consumes token quota reported by another source, +// letting one-api gather billing from different sources.
+func ConsumeToken(c *gin.Context) { + tokenPatch := new(consumeTokenRequest) + err := c.ShouldBindJSON(tokenPatch) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -197,6 +232,106 @@ func UpdateToken(c *gin.Context) { }) return } + + userID := c.GetInt(ctxkey.Id) + tokenID := c.GetInt(ctxkey.TokenId) + cleanToken, err := model.GetTokenByIds(tokenID, userID) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + if cleanToken.Status != model.TokenStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌未启用", + }) + return + } + + if cleanToken.Status == model.TokenStatusExpired && + cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", + }) + return + } + if cleanToken.Status == model.TokenStatusExhausted && + cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", + }) + return + } + + // let admin to add or subtract used quota, + // make it possible to aggregate billings from different sources. + if cleanToken.RemainQuota < int64(tokenPatch.AddUsedQuota) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "剩余额度不足", + }) + return + } + + if err = model.DecreaseTokenQuota(cleanToken.Id, int64(tokenPatch.AddUsedQuota)); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + model.RecordConsumeLog(c.Request.Context(), + userID, 0, 0, 0, tokenPatch.AddReason, cleanToken.Name, + int64(tokenPatch.AddUsedQuota), + fmt.Sprintf("外部(%s)消耗 %s", + tokenPatch.AddReason, common.LogQuota(int64(tokenPatch.AddUsedQuota)))) + + err = cleanToken.Update() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": cleanToken, + }) + + return +} + +func UpdateToken(c *gin.Context) { + userId := c.GetInt(ctxkey.Id) + statusOnly := c.Query("status_only") + tokenPatch := new(model.Token) + err := c.ShouldBindJSON(tokenPatch) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + token := new(model.Token) + if err = copier.Copy(token, tokenPatch); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = validateToken(c, token) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -205,6 +340,7 @@ func UpdateToken(c *gin.Context) { }) return } + cleanToken, err := model.GetTokenByIds(token.Id, userId) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -213,33 +349,50 @@ func UpdateToken(c *gin.Context) { }) return } - if token.Status == model.TokenStatusEnabled { - if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { + + switch token.Status { + case model.TokenStatusEnabled: + if cleanToken.Status == model.TokenStatusExpired && + cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 && + token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", }) return } - if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && 
!cleanToken.UnlimitedQuota { + if cleanToken.Status == model.TokenStatusExhausted && + cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota && + token.RemainQuota <= 0 && !token.UnlimitedQuota { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", }) return } + case model.TokenStatusExhausted: + if token.RemainQuota > 0 || token.UnlimitedQuota { + token.Status = model.TokenStatusEnabled + } + case model.TokenStatusExpired: + if token.ExpiredTime == -1 || token.ExpiredTime > helper.GetTimestamp() { + token.Status = model.TokenStatusEnabled + } } + if statusOnly != "" { cleanToken.Status = token.Status } else { // If you add more fields, please also update token.Update() cleanToken.Name = token.Name cleanToken.ExpiredTime = token.ExpiredTime - cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.Models = token.Models cleanToken.Subnet = token.Subnet + cleanToken.RemainQuota = token.RemainQuota + cleanToken.Status = token.Status } + err = cleanToken.Update() if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/controller/user.go b/controller/user.go index e79881c2..05c4d66a 100644 --- a/controller/user.go +++ b/controller/user.go @@ -3,17 +3,17 @@ package controller import ( "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/ctxkey" - "github.com/songquanpeng/one-api/common/random" - "github.com/songquanpeng/one-api/model" "net/http" "strconv" "time" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/model" ) type LoginRequest struct { @@ -77,6 +77,12 @@ func SetupLogin(user *model.User, c *gin.Context) { }) return } + + // set auth header + // c.Set("id", user.Id) + // GenerateAccessToken(c) + // c.Header("Authorization", user.AccessToken) + cleanUser := model.User{ Id: user.Id, Username: user.Username, @@ -173,7 +179,6 @@ func Register(c *gin.Context) { }) return } - c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -343,6 +348,16 @@ func GetAffCode(c *gin.Context) { return } +// GetSelfByToken get user by openai api token +func GetSelfByToken(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "uid": c.GetInt("id"), + "token_id": c.GetInt("token_id"), + "username": c.GetString("username"), + }) + return +} + func GetSelf(c *gin.Context) { id := c.GetInt(ctxkey.Id) user, err := model.GetUserById(id, false) diff --git a/go.mod b/go.mod index ada53bc3..5ce08ca4 100644 --- a/go.mod +++ b/go.mod @@ -1,111 +1,126 @@ module github.com/songquanpeng/one-api // +heroku goVersion go1.18 -go 1.20 +go 1.23 + +toolchain go1.23.0 require ( - cloud.google.com/go/iam v1.1.10 - github.com/aws/aws-sdk-go-v2 v1.27.0 - github.com/aws/aws-sdk-go-v2/credentials v1.17.15 - github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 - github.com/gin-contrib/cors v1.7.2 - github.com/gin-contrib/gzip v1.0.1 - github.com/gin-contrib/sessions v1.0.1 - github.com/gin-contrib/static v1.1.2 + cloud.google.com/go/iam v1.3.1 + github.com/Laisky/go-utils/v4 v4.10.0 + github.com/aws/aws-sdk-go-v2 v1.32.7 + github.com/aws/aws-sdk-go-v2/credentials v1.17.48 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.1 + github.com/gin-contrib/cors v1.7.3 + 
github.com/gin-contrib/gzip v1.1.0 + github.com/gin-contrib/sessions v1.0.2 + github.com/gin-contrib/static v1.1.3 github.com/gin-gonic/gin v1.10.0 - github.com/go-playground/validator/v10 v10.20.0 + github.com/go-playground/validator/v10 v10.23.0 github.com/go-redis/redis/v8 v8.11.5 github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 - github.com/gorilla/websocket v1.5.1 + github.com/gorilla/websocket v1.5.3 github.com/jinzhu/copier v0.4.0 github.com/joho/godotenv v1.5.1 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.7 github.com/smartystreets/goconvey v1.8.1 - github.com/stretchr/testify v1.9.0 - golang.org/x/crypto v0.24.0 - golang.org/x/image v0.18.0 - google.golang.org/api v0.187.0 - gorm.io/driver/mysql v1.5.6 - gorm.io/driver/postgres v1.5.7 - gorm.io/driver/sqlite v1.5.5 - gorm.io/gorm v1.25.10 + github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.32.0 + golang.org/x/image v0.23.0 + golang.org/x/sync v0.10.0 + google.golang.org/api v0.215.0 + gorm.io/driver/mysql v1.5.7 + gorm.io/driver/postgres v1.5.11 + gorm.io/driver/sqlite v1.5.7 + gorm.io/gorm v1.25.12 ) require ( - cloud.google.com/go/auth v0.6.1 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.3.0 // indirect - filippo.io/edwards25519 v1.1.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect - github.com/aws/smithy-go v1.20.2 // indirect - github.com/bytedance/sonic v1.11.6 // indirect - github.com/bytedance/sonic/loader v0.1.1 // indirect + cloud.google.com/go/auth v0.13.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect + github.com/GoWebProd/gip v0.0.0-20230623090727-b60d41d5d320 // indirect + github.com/GoWebProd/uuid7 v0.0.0-20231130161441-17ee54b097d4 // indirect + github.com/Laisky/errors/v2 v2.0.1 // indirect + github.com/Laisky/fast-skiplist/v2 v2.0.1 // indirect + github.com/Laisky/go-chaining v0.0.0-20180507092046-43dcdc5a21be // indirect + github.com/Laisky/golang-fifo v1.0.1-0.20240403091456-fc83d5e38c0b // indirect + github.com/Laisky/graphql v1.0.6 // indirect + github.com/Laisky/zap v1.27.1-0.20240628060440-a253d90172e3 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 // indirect + github.com/aws/smithy-go v1.22.1 // indirect + github.com/bytedance/sonic v1.12.6 // indirect + github.com/bytedance/sonic/loader v0.2.1 // indirect + github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/gabriel-vasile/mimetype v1.4.7 // indirect + github.com/gammazero/deque v0.2.1 // indirect github.com/gin-contrib/sse v0.1.0 // 
indirect - github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/goccy/go-json v0.10.3 // indirect - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.4 // indirect - github.com/google/s2a-go v0.1.7 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect - github.com/googleapis/gax-go/v2 v2.12.5 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/goccy/go-json v0.10.4 // indirect + github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932 // indirect + github.com/google/s2a-go v0.1.8 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gorilla/context v1.1.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/sessions v1.2.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect - github.com/kr/text v0.2.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/smarty/assertions v1.15.0 // indirect + github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect - go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect - go.opentelemetry.io/otel v1.24.0 // indirect - go.opentelemetry.io/otel/metric v1.24.0 // indirect - go.opentelemetry.io/otel/trace v1.24.0 // indirect - golang.org/x/arch v0.8.0 // indirect - golang.org/x/net v0.26.0 // indirect - golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.21.0 // indirect - golang.org/x/text v0.16.0 // indirect - golang.org/x/time v0.5.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect - google.golang.org/grpc v1.64.1 // indirect - google.golang.org/protobuf v1.34.2 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect + 
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect + go.opentelemetry.io/otel v1.29.0 // indirect + go.opentelemetry.io/otel/metric v1.29.0 // indirect + go.opentelemetry.io/otel/trace v1.29.0 // indirect + go.uber.org/automaxprocs v1.5.3 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/arch v0.12.0 // indirect + golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/oauth2 v0.24.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/term v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect + golang.org/x/time v0.8.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 // indirect + google.golang.org/grpc v1.67.3 // indirect + google.golang.org/protobuf v1.36.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 53db8df2..6d01bb31 100644 --- a/go.sum +++ b/go.sum @@ -1,127 +1,129 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go/auth v0.6.1 h1:T0Zw1XM5c1GlpN2HYr2s+m3vr1p2wy+8VN+Z1FKxW38= -cloud.google.com/go/auth v0.6.1/go.mod h1:eFHG7zDzbXHKmjJddFG/rBlcGp6t25SwRUiEQSlO4x4= -cloud.google.com/go/auth/oauth2adapt v0.2.2 h1:+TTV8aXpjeChS9M+aTtN/TjdQnzJvmzKFt//oWu7HX4= -cloud.google.com/go/auth/oauth2adapt v0.2.2/go.mod h1:wcYjgpZI9+Yu7LyYBg4pqSiaRkfEK3GQcpb7C/uyF1Q= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= -cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= -github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= -github.com/aws/aws-sdk-go-v2/credentials v1.17.15 h1:YDexlvDRCA8ems2T5IP1xkMtOZ1uLJOCJdTr0igs5zo= -github.com/aws/aws-sdk-go-v2/credentials v1.17.15/go.mod h1:vxHggqW6hFNaeNC0WyXS3VdyjcV0a4KMUY4dKJ96buU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3 h1:Fihjyd6DeNjcawBEGLH9dkIEUi6AdhucDKPE9nJ4QiY= -github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.3/go.mod h1:opvUj3ismqSCxYc+m4WIjPL0ewZGtvp0ess7cKvBPOQ= -github.com/aws/smithy-go v1.20.2 
h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= -github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= -github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= -github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= -github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs= +cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q= +cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU= +cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/iam v1.3.1 h1:KFf8SaT71yYq+sQtRISn90Gyhyf4X8RGgeAVC8XGf3E= +cloud.google.com/go/iam v1.3.1/go.mod h1:3wMtuyT4NcbnYNPLMBzYRFiEfjKfJlLVLrisE7bwm34= +github.com/GoWebProd/gip v0.0.0-20230623090727-b60d41d5d320 h1:lddR7TsA0fUX8Kh+oc01z4GwmCoBveT79zhNLK43xLk= +github.com/GoWebProd/gip v0.0.0-20230623090727-b60d41d5d320/go.mod h1:eTC6ev1JFq+zoOOS5WKuHdBwtihV/9/Ouv3fZ3ufS0A= +github.com/GoWebProd/uuid7 v0.0.0-20231130161441-17ee54b097d4 h1:IjxKU4UMzoALLBo3JF7QNi5E0H22R2lDKT3RM9yNCQU= +github.com/GoWebProd/uuid7 v0.0.0-20231130161441-17ee54b097d4/go.mod h1:YIx3++ypr3VYDYlz62Zs6zxq/iPT5e9vShuqxwL/5Us= +github.com/Laisky/errors/v2 v2.0.1 h1:yqCBrRzaP012AMB+7fVlXrP34OWRHrSO/hZ38CFdH84= +github.com/Laisky/errors/v2 v2.0.1/go.mod h1:mTn1LHSmKm4CYug0rpYO7rz13dp/DKrtzlSELSrxvT0= +github.com/Laisky/fast-skiplist/v2 v2.0.1 h1:mZD3G/cwNovXsd21Vyvt3HCI9dEg1V7OD64qXsUUgpQ= +github.com/Laisky/fast-skiplist/v2 v2.0.1/go.mod h1:JlDGOmsJOwW7Uo46L9aVG7nJAeqP7X7nfU5TABOiiE8= +github.com/Laisky/go-chaining v0.0.0-20180507092046-43dcdc5a21be h1:7Rxhm6IjOtDAyj8eScOFntevwzkWhx94zi48lxo4m4w= +github.com/Laisky/go-chaining v0.0.0-20180507092046-43dcdc5a21be/go.mod h1:1mdzaETo0kjvCQICPSePsoaatJN4l7JvEA1200lyevo= +github.com/Laisky/go-utils/v4 v4.10.0 h1:kSYHk0ONde1ZVMNVw/sLAOAKu/+9mOS9KtVLtY0upAo= +github.com/Laisky/go-utils/v4 v4.10.0/go.mod h1:TepxY90+WGujsezm9rN7Ctk8faB2Xgtz4bQ/C8PHBJ8= +github.com/Laisky/golang-fifo v1.0.1-0.20240403091456-fc83d5e38c0b h1:o2BuVyXFkDTkEiuz1Ur32jGvaErgEHqhb8AtTIkrvE0= +github.com/Laisky/golang-fifo v1.0.1-0.20240403091456-fc83d5e38c0b/go.mod h1:j90tUqwBaEncIzpAd6ZGPZHWjclgAyMY8fdOqsewitE= +github.com/Laisky/graphql v1.0.6 h1:NEULGxfxo+wbsW2OmqBXOMNUGgqo8uFjWNabwuNK10g= +github.com/Laisky/graphql v1.0.6/go.mod h1:zaKVqXmMQTnTkFJ2AA53oyBWMzlGCnzr3aodKTrtOxI= +github.com/Laisky/zap v1.27.1-0.20240628060440-a253d90172e3 h1:SD0siYXoInGc6MqVsmJrBJl4TNHYXfeGu92fRSUNnas= +github.com/Laisky/zap v1.27.1-0.20240628060440-a253d90172e3/go.mod h1:HABqM5YDQlPq8w+Pmp9h/x9F6Vy+3oHBLP+2+pBoaJw= +github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw= +github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod 
h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= +github.com/aws/aws-sdk-go-v2/credentials v1.17.48 h1:IYdLD1qTJ0zanRavulofmqut4afs45mOWEI+MzZtTfQ= +github.com/aws/aws-sdk-go-v2/credentials v1.17.48/go.mod h1:tOscxHN3CGmuX9idQ3+qbkzrjVIx32lqDSU1/0d/qXs= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26/go.mod h1:FR8f4turZtNy6baO0KJ5FJUmXH/cSkI9fOngs0yl6mA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 h1:zXFLuEuMMUOvEARXFUVJdfqZ4bvvSgdGRq/ATcrQxzM= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26/go.mod h1:3o2Wpy0bogG1kyOPrgkXA8pgIfEEv0+m19O9D5+W8y8= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.1 h1:rqrvjFScEwD7VfP4L0hhnrXyTkgUkpQWAdwOrW2slOo= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.1/go.mod h1:Vn5GopXsOAC6kbwzjfM6V37dxc4mo4J4xCRiF27pSZA= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/brianvoe/gofakeit/v6 v6.23.2 h1:lVde18uhad5wII/f5RMVFLtdQNE0HaGFuBUXmYKk8i8= +github.com/brianvoe/gofakeit/v6 v6.23.2/go.mod h1:Ow6qC71xtwm79anlwKRlWZW6zVq9D2XHE4QSSMP/rU8= +github.com/bytedance/sonic v1.12.6 h1:/isNmCUF2x3Sh8RAp/4mh4ZGkcFAX/hLrzrK3AvpRzk= +github.com/bytedance/sonic v1.12.6/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/bytedance/sonic/loader v0.2.1 h1:1GgorWTqf12TA8mma4DDSbaQigE2wOgQo7iCjjJv3+E= +github.com/bytedance/sonic/loader v0.2.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/deckarep/golang-set/v2 v2.1.0 h1:g47V4Or+DUdzbs8FxCCmgb6VYd+ptPAngjM6dtGktsI= +github.com/deckarep/golang-set/v2 v2.1.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dlclark/regexp2 
v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= -github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= -github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= -github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= -github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E= -github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE= -github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4= -github.com/gin-contrib/sessions v1.0.1 h1:3hsJyNs7v7N8OtelFmYXFrulAf6zSR7nW/putcPEHxI= -github.com/gin-contrib/sessions v1.0.1/go.mod h1:ouxSFM24/OgIud5MJYQJLpy6AwxQ5EYO9yLhbtObGkM= +github.com/gabriel-vasile/mimetype v1.4.7 h1:SKFKl7kD0RiPdbht0s7hFtjl489WcQ1VyPW8ZzUMYCA= +github.com/gabriel-vasile/mimetype v1.4.7/go.mod h1:GDlAgAyIRT27BhFl53XNAFtfjzOkLaF35JdEG0P7LtU= +github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= +github.com/gin-contrib/cors v1.7.3 h1:hV+a5xp8hwJoTw7OY+a70FsL8JkVVFTXw9EcfrYUdns= +github.com/gin-contrib/cors v1.7.3/go.mod h1:M3bcKZhxzsvI+rlRSkkxHyljJt1ESd93COUvemZ79j4= +github.com/gin-contrib/gzip v1.1.0 h1:kVw7Nr9M+Z6Ch4qo7aGMbiqxDeyQFru+07MgAcUF62M= +github.com/gin-contrib/gzip v1.1.0/go.mod h1:iHJXCup4CWiKyPUEl+GwkHjchl+YyYuMKbOCiXujPIA= +github.com/gin-contrib/sessions v1.0.2 h1:UaIjUvTH1cMeOdj3in6dl+Xb6It8RiKRF9Z1anbUyCA= +github.com/gin-contrib/sessions v1.0.2/go.mod h1:KxKxWqWP5LJVDCInulOl4WbLzK2KSPlLesfZ66wRvMs= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-contrib/static v1.1.2 h1:c3kT4bFkUJn2aoRU3s6XnMjJT8J6nNWJkR0NglqmlZ4= -github.com/gin-contrib/static v1.1.2/go.mod h1:Fw90ozjHCmZBWbgrsqrDvO28YbhKEKzKp8GixhR4yLw= +github.com/gin-contrib/static v1.1.3 h1:WLOpkBtMDJ3gATFZgNJyVibFMio/UHonnueqJsQ0w4U= +github.com/gin-contrib/static v1.1.3/go.mod h1:zejpJ/YWp8cZj/6EpiL5f/+skv5daQTNwRx1E8Pci30= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-json-experiment/json v0.0.0-20231011163920-8aa127fd5801 h1:PRieymvnGuBZUnWVQPBOemqlIhRznqtSxs/1LqlWe20= 
+github.com/go-json-experiment/json v0.0.0-20231011163920-8aa127fd5801/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= -github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= -github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-playground/validator/v10 v10.23.0 h1:/PwmTwZhS0dPkav3cdK9kV1FsAmrL8sThn8IHr/sO+o= +github.com/go-playground/validator/v10 v10.23.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= -github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= -github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM= +github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod 
h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932 h1:5/4TSDzpDnHQ8rKEEQBjRlYx77mHOvXu08oGchxej7o= +github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932/go.mod h1:cC6EdPbj/17GFCPDK39NRarlMI+kt+O60S12cNB5J9Y= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= -github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= -github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= +github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= -github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= -github.com/googleapis/gax-go/v2 v2.12.5 h1:8gw9KZK8TiVKB6q3zHY3SBzLnrGp6HQjyfYBYGmXdxA= -github.com/googleapis/gax-go/v2 v2.12.5/go.mod h1:BUDKcWo+RaKq5SC9vVYL0wLADa3VcfswbOMMRmB9H3E= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= +github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod 
h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= @@ -130,12 +132,13 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/graph-gophers/graphql-go v1.5.0/go.mod h1:YtmJZDLbF1YYNrlNAuiO5zAStUWc3XZT07iGsVqe1Os= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= @@ -153,10 +156,11 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= @@ -171,147 +175,132 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nxadm/tail 
v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= -github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= +github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= +github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= -go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo= -go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo= -go.opentelemetry.io/otel/metric v1.24.0 h1:6EhoGWWK28x1fbpA4tYTOWBkPefTDQnb8WSGXlc88kI= -go.opentelemetry.io/otel/metric v1.24.0/go.mod h1:VYhLe1rFfxuTXLgj4CBiyz+9WYBA8pNGJgDcSFRKBco= -go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y1YELI= -go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= -golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= +go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/trace v1.6.3/go.mod 
h1:GNJQusJlUgZl9/TQBPKU/Y/ty+0iVB5fjhKeJGZPGFs= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= +go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= +go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg= +golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= -golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68= +golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY= +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug= +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod 
h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= -golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text 
v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.187.0 h1:Mxs7VATVC2v7CY+7Xwm4ndkX71hpElcvx0D1Ji/p1eo= -google.golang.org/api v0.187.0/go.mod h1:KIHlTc4x7N7gKKuVsdmfBXN13yEEWXWFURWY6SBp2gk= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 h1:MuYw1wJzT+ZkybKfaOXKp5hJiZDn2iHaXRw0mRYdHSc= -google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4/go.mod h1:px9SlOOZBg1wM1zdnr8jEL4CNGUBZ+ZKYtNPApNQc4c= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d h1:k3zyW3BYYR30e8v3x0bTDdE9vpYFjZHK+HcyqkrppWk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= 
-google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/api v0.215.0 h1:jdYF4qnyczlEz2ReWIsosNLDuzXyvFHJtI5gcr0J7t0= +google.golang.org/api v0.215.0/go.mod h1:fta3CVtuJYOEdugLNWm6WodzOS8KdFckABwN4I40hzY= +google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 h1:CkkIfIt50+lT6NHAVoRYEyAvQGFM7xEwXUUywFvEb3Q= +google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576/go.mod h1:1R3kvZ1dtP3+4p4d3G8uJ8rFk/fWlScl38vanWACI08= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8 h1:TqExAhdPaB60Ux47Cn0oLV07rGnxZzIsaRhQaqS666A= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241223144023-3abc09e42ca8/go.mod h1:lcTa1sDdWEIHMWlITnIczmw5w60CF9ffkb8Z+DVmmjA= +google.golang.org/grpc v1.67.3 h1:OgPcDAFKHnH8X3O4WcO4XUc8GRDeKsKReqbQtiCj7N8= +google.golang.org/grpc v1.67.3/go.mod h1:YGaHCc6Oap+FzBJTZLBzkGSYt/cvGPFTPxkn7QfSU8s= +google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= +google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= -gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= -gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= -gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= -gorm.io/driver/sqlite v1.5.5 h1:7MDMtUZhV065SilG62E0MquljeArQZNfJnjd9i9gx3E= -gorm.io/driver/sqlite v1.5.5/go.mod h1:6NgQ7sQWAIFsPrJJl1lSNSu2TABh0ZZ/zm5fosATavE= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314= +gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= +gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= +gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= -gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod 
h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/main.go b/main.go index 67a3cd95..97d712b8 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,11 @@ package main import ( "embed" + "encoding/base64" "fmt" + "os" + "strconv" + "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" @@ -16,8 +20,6 @@ import ( "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/router" - "os" - "strconv" ) //go:embed web/build/* @@ -99,8 +101,15 @@ func main() { server.Use(middleware.RequestId()) middleware.SetUpLogger(server) // Initialize session store - store := cookie.NewStore([]byte(config.SessionSecret)) - server.Use(sessions.Sessions("session", store)) + sessionSecret, err := base64.StdEncoding.DecodeString(config.SessionSecret) + if err != nil { + logger.SysLog("session secret is not base64 encoded, using raw value instead") + store := cookie.NewStore([]byte(config.SessionSecret)) + server.Use(sessions.Sessions("session", store)) + } else { + store := cookie.NewStore(sessionSecret, sessionSecret) + server.Use(sessions.Sessions("session", store)) + } router.SetRouter(server, buildFS) var port = os.Getenv("PORT") diff --git a/middleware/auth.go b/middleware/auth.go index e0019838..7aa2b26f 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -2,14 +2,16 @@ package middleware import ( "fmt" + "net/http" + "strings" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" - "net/http" - "strings" ) func authHelper(c *gin.Context, minRole int) { @@ -19,6 +21,7 @@ func authHelper(c *gin.Context, minRole int) { id := session.Get("id") status := session.Get("status") if username == nil { + logger.SysLog("no user session found, try to use access token") // Check access token accessToken := c.Request.Header.Get("Authorization") if accessToken == "" { @@ -29,6 +32,7 @@ func authHelper(c *gin.Context, minRole int) { c.Abort() return } + user := model.ValidateAccessToken(accessToken) if user != nil && user.Username != "" { // Token is valid @@ -93,7 +97,7 @@ func TokenAuth() func(c *gin.Context) { ctx := c.Request.Context() key := c.Request.Header.Get("Authorization") key = strings.TrimPrefix(key, "Bearer ") - key = strings.TrimPrefix(key, "sk-") + key = strings.TrimPrefix(strings.TrimPrefix(key, "sk-"), "laisky-") parts := strings.Split(key, "-") key = parts[0] token, err := model.ValidateUserToken(key) diff --git a/middleware/distributor.go b/middleware/distributor.go index 0c4b04c3..1d376dfc 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,17 +2,21 @@ package middleware import ( "fmt" + "net/http" + "strconv" + "strings" + + gutils "github.com/Laisky/go-utils/v4" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/billing/ratio" 
"github.com/songquanpeng/one-api/relay/channeltype" - "net/http" - "strconv" ) type ModelRequest struct { - Model string `json:"model"` + Model string `json:"model" form:"model"` } func Distribute() func(c *gin.Context) { @@ -58,9 +62,32 @@ func Distribute() func(c *gin.Context) { } func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { + // one channel could relates to multiple groups, + // and each groud has individual ratio, + // set minimal group ratio as channel_ratio + var minimalRatio float64 = -1 + for _, grp := range strings.Split(channel.Group, ",") { + v := ratio.GetGroupRatio(grp) + if minimalRatio < 0 || v < minimalRatio { + minimalRatio = v + } + } + logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio)) + c.Set(ctxkey.ChannelRatio, minimalRatio) + c.Set(ctxkey.ChannelModel, channel) + + // generate an unique cost id for each request + if _, ok := c.Get(ctxkey.RequestId); !ok { + c.Set(ctxkey.RequestId, gutils.UUID7()) + } + c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelName, channel.Name) + c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type")) + if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { + c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) + } c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/middleware/gzip.go b/middleware/gzip.go new file mode 100644 index 00000000..4d4ce0c2 --- /dev/null +++ b/middleware/gzip.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "compress/gzip" + "github.com/gin-gonic/gin" + "io" + "net/http" +) + +func GzipDecodeMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.GetHeader("Content-Encoding") == "gzip" { + gzipReader, err := gzip.NewReader(c.Request.Body) + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + defer gzipReader.Close() + + // Replace the request body with the decompressed data + c.Request.Body = io.NopCloser(gzipReader) + } + + // Continue processing the request + c.Next() + } +} diff --git a/middleware/logger.go b/middleware/logger.go index 191364f8..587d748c 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" ) diff --git a/middleware/recover.go b/middleware/recover.go index cfc3f827..a690c77b 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -14,11 +14,11 @@ func RelayPanicRecover() gin.HandlerFunc { defer func() { if err := recover(); err != nil { ctx := c.Request.Context() - logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err)) - logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) - logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path)) + logger.Errorf(ctx, "panic detected: %v", err) + logger.Errorf(ctx, "stacktrace from panic: %s", string(debug.Stack())) + logger.Errorf(ctx, "request: %s %s", c.Request.Method, c.Request.URL.Path) body, _ := common.GetRequestBody(c) - logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body))) + logger.Errorf(ctx, "request body: %s", string(body)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. 
Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err), diff --git a/middleware/request-id.go b/middleware/request-id.go index bef09e32..c1f3adc2 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" ) diff --git a/middleware/utils.go b/middleware/utils.go index 4d2f8092..46120f2a 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,12 +1,13 @@ package middleware import ( - "fmt" + "strings" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -24,28 +25,30 @@ func getRequestModel(c *gin.Context) (string, error) { var modelRequest ModelRequest err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { - return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed") } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + + switch { + case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"): if modelRequest.Model == "" { modelRequest.Model = "text-moderation-stable" } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + case strings.HasSuffix(c.Request.URL.Path, "embeddings"): if modelRequest.Model == "" { modelRequest.Model = c.Param("model") } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"), + strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"): if modelRequest.Model == "" { modelRequest.Model = "dall-e-2" } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"), + strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"): if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } } + return modelRequest.Model, nil } diff --git a/model/cache.go b/model/cache.go index cfb0f8a4..997d8361 100644 --- a/model/cache.go +++ b/model/cache.go @@ -3,8 +3,8 @@ package model import ( "context" "encoding/json" - "errors" "fmt" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" @@ -51,6 +51,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { } return &token, nil } + err = json.Unmarshal([]byte(tokenObjectString), &token) return &token, err } diff --git a/model/channel.go b/model/channel.go index 759dfd4f..4b0f4b01 100644 --- a/model/channel.go +++ b/model/channel.go @@ -37,6 +37,7 @@ type Channel struct { ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` Config string `json:"config"` + SystemPrompt *string `json:"system_prompt" gorm:"type:text"` } type ChannelConfig struct { diff --git a/model/cost.go b/model/cost.go new file mode 100644 index 00000000..82efe855 --- /dev/null +++ b/model/cost.go @@ -0,0 +1,75 @@ +package model + +import ( + "fmt" + "math/rand" + "sync" + + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/helper" + 
"github.com/songquanpeng/one-api/common/logger" +) + +type UserRequestCost struct { + Id int `json:"id"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UserID int `json:"user_id"` + RequestID string `json:"request_id"` + Quota int64 `json:"quota"` + CostUSD float64 `json:"cost_usd" gorm:"-"` +} + +// NewUserRequestCost create a new UserRequestCost +func NewUserRequestCost(userID int, quotaID string, quota int64) *UserRequestCost { + return &UserRequestCost{ + CreatedTime: helper.GetTimestamp(), + UserID: userID, + RequestID: quotaID, + Quota: quota, + } +} + +func (docu *UserRequestCost) Insert() error { + go removeOldRequestCost() + + err := DB.Create(docu).Error + return errors.Wrap(err, "failed to insert UserRequestCost") +} + +// GetCostByRequestId get cost by request id +func GetCostByRequestId(reqid string) (*UserRequestCost, error) { + if reqid == "" { + return nil, errors.New("request id is empty") + } + + docu := &UserRequestCost{RequestID: reqid} + var err error = nil + if err = DB.First(docu, "request_id = ?", reqid).Error; err != nil { + return nil, errors.Wrap(err, "failed to get cost by request id") + } + + docu.CostUSD = float64(docu.Quota) / 500000 + return docu, nil +} + +var muRemoveOldRequestCost sync.Mutex + +// removeOldRequestCost remove old request cost data, +// this function will be executed every 1/1000 times. +func removeOldRequestCost() { + if rand.Float32() > 0.001 { + return + } + + if ok := muRemoveOldRequestCost.TryLock(); !ok { + return + } + defer muRemoveOldRequestCost.Unlock() + + err := DB. + Where("created_time < ?", helper.GetTimestamp()-3600*24*7). + Delete(&UserRequestCost{}).Error + if err != nil { + logger.SysError(fmt.Sprintf("failed to remove old request cost: %s", err.Error())) + } +} diff --git a/model/log.go b/model/log.go index 6fba776a..58fdd513 100644 --- a/model/log.go +++ b/model/log.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -152,7 +153,11 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { } func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { - tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") + ifnull := "ifnull" + if common.UsingPostgreSQL { + ifnull = "COALESCE" + } + tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull)) if username != "" { tx = tx.Where("username = ?", username) } @@ -176,7 +181,11 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") + ifnull := "ifnull" + if common.UsingPostgreSQL { + ifnull = "COALESCE" + } + tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull)) if username != "" { tx = tx.Where("username = ?", username) } diff --git a/model/main.go b/model/main.go index 72e271a0..fcf2b3ae 100644 --- a/model/main.go +++ b/model/main.go @@ -3,6 +3,11 @@ package model import ( "database/sql" "fmt" + "os" + "strings" + "time" + + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" 
"github.com/songquanpeng/one-api/common/env" @@ -13,9 +18,6 @@ import ( "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" - "os" - "strings" - "time" ) var DB *gorm.DB @@ -28,7 +30,7 @@ func CreateRootAccountIfNeed() error { logger.SysLog("no user exists, creating a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { - return err + return errors.WithStack(err) } accessToken := random.GetUUID() if config.InitialRootAccessToken != "" { @@ -157,7 +159,7 @@ func migrateDB() error { if err = DB.AutoMigrate(&Log{}); err != nil { return err } - if err = DB.AutoMigrate(&Channel{}); err != nil { + if err = DB.AutoMigrate(&UserRequestCost{}); err != nil { return err } return nil @@ -220,10 +222,10 @@ func setDBConns(db *gorm.DB) *sql.DB { func closeDB(db *gorm.DB) error { sqlDB, err := db.DB() if err != nil { - return err + return errors.WithStack(err) } err = sqlDB.Close() - return err + return errors.WithStack(err) } func CloseDB() error { diff --git a/model/option.go b/model/option.go index bed8d4c3..8fd30aee 100644 --- a/model/option.go +++ b/model/option.go @@ -28,6 +28,7 @@ func InitOptionMap() { config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": config.GitHubOAuthEnabled = boolValue + case "OidcEnabled": + config.OidcEnabled = boolValue case "WeChatAuthEnabled": config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": @@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) { config.LarkClientId = value case "LarkClientSecret": config.LarkClientSecret = value + case "OidcClientId": + config.OidcClientId = value + case "OidcClientSecret": + config.OidcClientSecret = value + case "OidcWellKnown": + config.OidcWellKnown = value + case "OidcAuthorizationEndpoint": + config.OidcAuthorizationEndpoint = value + case "OidcTokenEndpoint": + config.OidcTokenEndpoint = value + case "OidcUserinfoEndpoint": + config.OidcUserinfoEndpoint = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/redemption.go b/model/redemption.go index 45871a71..6c1a63d3 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -1,8 +1,8 @@ package model import ( - "errors" "fmt" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "gorm.io/gorm" diff --git a/model/token.go b/model/token.go index 96e6b491..a8351b24 100644 --- a/model/token.go +++ b/model/token.go @@ -1,8 +1,9 @@ package model import ( - "errors" "fmt" + + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -30,7 +31,7 @@ type Token struct { RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` UnlimitedQuota bool 
`json:"unlimited_quota" gorm:"default:false"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota - Models *string `json:"models" gorm:"default:''"` // allowed models + Models *string `json:"models" gorm:"type:text"` // allowed models Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet } @@ -65,9 +66,10 @@ func ValidateUserToken(key string) (token *Token, err error) { if err != nil { logger.SysError("CacheGetTokenByKey failed: " + err.Error()) if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, errors.New("无效的令牌") + return nil, errors.Wrap(err, "token not found") } - return nil, errors.New("令牌验证失败") + + return nil, errors.Wrap(err, "failed to get token by key") } if token.Status == TokenStatusExhausted { return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) @@ -121,30 +123,40 @@ func GetTokenById(id int) (*Token, error) { return &token, err } -func (token *Token) Insert() error { +func (t *Token) Insert() error { var err error - err = DB.Create(token).Error + err = DB.Create(t).Error return err } // Update Make sure your token's fields is completed, because this will update non-zero values -func (token *Token) Update() error { +func (t *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error + err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error return err } -func (token *Token) SelectUpdate() error { +func (t *Token) SelectUpdate() error { // This can update zero values - return DB.Model(token).Select("accessed_time", "status").Updates(token).Error + return DB.Model(t).Select("accessed_time", "status").Updates(t).Error } -func (token *Token) Delete() error { +func (t *Token) Delete() error { var err error - err = DB.Delete(token).Error + err = DB.Delete(t).Error return err } +func (t *Token) GetModels() string { + if t == nil { + return "" + } + if t.Models == nil { + return "" + } + return *t.Models +} + func DeleteTokenById(id int, userId int) (err error) { // Why we need userId here? In case user want to delete other's token. 
if id == 0 || userId == 0 { @@ -254,14 +266,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { token, err := GetTokenById(tokenId) + if err != nil { + return err + } if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) } else { err = IncreaseUserQuota(token.UserId, -quota) } - if err != nil { - return err - } if !token.UnlimitedQuota { if quota > 0 { err = DecreaseTokenQuota(tokenId, quota) diff --git a/model/user.go b/model/user.go index 924d72f9..55dd0570 100644 --- a/model/user.go +++ b/model/user.go @@ -1,8 +1,10 @@ package model import ( - "errors" "fmt" + "strings" + + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" @@ -10,7 +12,6 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" "gorm.io/gorm" - "strings" ) const ( @@ -39,6 +40,7 @@ type User struct { GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"` + OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -245,6 +247,14 @@ func (user *User) FillUserByLarkId() error { return nil } +func (user *User) FillUserByOidcId() error { + if user.OidcId == "" { + return errors.New("oidc id 为空!") + } + DB.Where(User{OidcId: user.OidcId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -277,6 +287,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsOidcIdAlreadyTaken(oidcId string) bool { + return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/monitor/manage.go b/monitor/manage.go index 946e78af..268d3924 100644 --- a/monitor/manage.go +++ b/monitor/manage.go @@ -1,10 +1,11 @@ package monitor import ( - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/relay/model" "net/http" "strings" + + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/model" ) func ShouldDisableChannel(err *model.Error, statusCode int) bool { @@ -18,31 +19,23 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { return true } switch err.Type { - case "insufficient_quota": - return true - // https://docs.anthropic.com/claude/reference/errors - case "authentication_error": - return true - case "permission_error": - return true - case "forbidden": + case "insufficient_quota", "authentication_error", "permission_error", "forbidden": return true } if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } - if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic - return true - } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { - return true - } - //if 
strings.Contains(err.Message, "quota") { - // return true - //} - if strings.Contains(err.Message, "credit") { - return true - } - if strings.Contains(err.Message, "balance") { + + lowerMessage := strings.ToLower(err.Message) + if strings.Contains(lowerMessage, "your access was terminated") || + strings.Contains(lowerMessage, "violation of our policies") || + strings.Contains(lowerMessage, "your credit balance is too low") || + strings.Contains(lowerMessage, "organization has been disabled") || + strings.Contains(lowerMessage, "credit") || + strings.Contains(lowerMessage, "balance") || + strings.Contains(lowerMessage, "permission denied") || + strings.Contains(lowerMessage, "organization has been restricted") || // groq + strings.Contains(lowerMessage, "已欠费") { return true } return false diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 00000000..e10a6764 --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "one-api", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/relay/adaptor.go b/relay/adaptor.go index 711e63bd..d6fbf058 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -16,6 +16,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/proxy" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" "github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/vertexai" "github.com/songquanpeng/one-api/relay/adaptor/xunfei" @@ -61,6 +62,9 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &vertexai.Adaptor{} case apitype.Proxy: return &proxy.Adaptor{} + case apitype.Replicate: + return &replicate.Adaptor{} } + return nil } diff --git a/relay/adaptor/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go index 42d49c0a..99456abf 100644 --- a/relay/adaptor/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -1,14 +1,15 @@ package aiproxy import ( - "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) type Adaptor struct { @@ -38,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return aiProxyLibraryRequest, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/aiproxy/main.go b/relay/adaptor/aiproxy/main.go index d64b6809..03eeefa4 100644 --- a/relay/adaptor/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -4,7 +4,6 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strconv" @@ -15,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go index 4aa8a11a..25fbc704 100644 --- a/relay/adaptor/ali/adaptor.go +++ b/relay/adaptor/ali/adaptor.go @@ -3,13 +3,14 @@ 
package ali import ( "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" ) // https://help.aliyun.com/zh/dashscope/developer-reference/api-details @@ -67,7 +68,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/ali/constants.go b/relay/adaptor/ali/constants.go index 3f24ce2e..f3d99520 100644 --- a/relay/adaptor/ali/constants.go +++ b/relay/adaptor/ali/constants.go @@ -1,7 +1,23 @@ package ali var ModelList = []string{ - "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", - "text-embedding-v1", + "qwen-turbo", "qwen-turbo-latest", + "qwen-plus", "qwen-plus-latest", + "qwen-max", "qwen-max-latest", + "qwen-max-longcontext", + "qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest", + "qwen-vl-ocr", "qwen-vl-ocr-latest", + "qwen-audio-turbo", + "qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest", + "qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest", + "qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct", + "qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct", + "qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat", + "qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat", + "qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1", + "qwen2-audio-instruct", "qwen-audio-chat", + "qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct", + "qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct", + "text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1", "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", } diff --git a/relay/adaptor/ali/image.go b/relay/adaptor/ali/image.go index 8261803d..a251d4d8 100644 --- a/relay/adaptor/ali/image.go +++ b/relay/adaptor/ali/image.go @@ -5,15 +5,16 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strings" + "time" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" - "time" ) func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index f9039dbe..78cd11ca 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -3,15 +3,16 @@ package ali import ( 
"bufio" "encoding/json" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" ) @@ -35,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { enableSearch = true aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } - if request.TopP >= 1 { - request.TopP = 0.9999 - } + request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) return &ChatRequest{ Model: aliModel, Input: Input{ @@ -59,7 +58,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { return &EmbeddingRequest{ - Model: "text-embedding-v1", + Model: request.Model, Input: struct { Texts []string `json:"texts"` }{ @@ -102,8 +101,9 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat StatusCode: resp.StatusCode, }, nil } - + requestModel := c.GetString(ctxkey.RequestModel) fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) + fullTextResponse.Model = requestModel jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go index 450b5f52..a680c7e2 100644 --- a/relay/adaptor/ali/model.go +++ b/relay/adaptor/ali/model.go @@ -16,13 +16,13 @@ type Input struct { } type Parameters struct { - TopP float64 `json:"top_p,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Seed uint64 `json:"seed,omitempty"` EnableSearch bool `json:"enable_search,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` ResultFormat string `json:"result_format,omitempty"` Tools []model.Tool `json:"tools,omitempty"` } diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go index bd0949be..76557b31 100644 --- a/relay/adaptor/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -1,13 +1,13 @@ package anthropic import ( - "errors" "fmt" "io" "net/http" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -20,6 +20,8 @@ func (a *Adaptor) Init(meta *meta.Meta) { } +// https://docs.anthropic.com/claude/reference/messages_post +// anthopic migrate to Message API func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil } @@ -47,10 +49,12 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } + + c.Set("claude_model", request.Model) return ConvertRequest(*request), nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, 
errors.New("request is nil") } diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go index 143d1efc..8ea7c4d8 100644 --- a/relay/adaptor/anthropic/constants.go +++ b/relay/adaptor/anthropic/constants.go @@ -3,7 +3,10 @@ package anthropic var ModelList = []string{ "claude-instant-1.2", "claude-2.0", "claude-2.1", "claude-3-haiku-20240307", + "claude-3-5-haiku-20241022", "claude-3-sonnet-20240229", "claude-3-opus-20240229", "claude-3-5-sonnet-20240620", + "claude-3-5-sonnet-20241022", + "claude-3-5-sonnet-latest", } diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index d3e306c8..22cb4aad 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -261,7 +261,6 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } return 0, nil, nil }) - common.SetEventStreamHeaders(c) var usage model.Usage diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index 47f76629..a6566d68 100644 --- a/relay/adaptor/anthropic/model.go +++ b/relay/adaptor/anthropic/model.go @@ -48,8 +48,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` @@ -66,6 +66,20 @@ type Error struct { Message string `json:"message"` } +type ResponseType string + +const ( + TypeError ResponseType = "error" + TypeStart ResponseType = "message_start" + TypeContentStart ResponseType = "content_block_start" + TypeContent ResponseType = "content_block_delta" + TypePing ResponseType = "ping" + TypeContentStop ResponseType = "content_block_stop" + TypeMessageDelta ResponseType = "message_delta" + TypeMessageStop ResponseType = "message_stop" +) + +// https://docs.anthropic.com/claude/reference/messages-streaming type Response struct { Id string `json:"id"` Type string `json:"type"` diff --git a/relay/adaptor/aws/adaptor.go b/relay/adaptor/aws/adaptor.go index 62221346..9a85f510 100644 --- a/relay/adaptor/aws/adaptor.go +++ b/relay/adaptor/aws/adaptor.go @@ -1,7 +1,6 @@ package aws import ( - "errors" "io" "net/http" @@ -9,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" "github.com/songquanpeng/one-api/relay/meta" @@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index 7142e46f..3fe3dfd8 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -29,10 +29,13 @@ var AwsModelIDMap = map[string]string{ "claude-instant-1.2": "anthropic.claude-instant-v1", "claude-2.0": "anthropic.claude-v2", "claude-2.1": "anthropic.claude-v2:1", - "claude-3-sonnet-20240229": 
"anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", } func awsModelID(requestModel string) (string, error) { diff --git a/relay/adaptor/aws/claude/model.go b/relay/adaptor/aws/claude/model.go index 6d00b688..10622887 100644 --- a/relay/adaptor/aws/claude/model.go +++ b/relay/adaptor/aws/claude/model.go @@ -11,8 +11,8 @@ type Request struct { Messages []anthropic.Message `json:"messages"` System string `json:"system,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go index 7b86c3b8..6cb64cde 100644 --- a/relay/adaptor/aws/llama3/model.go +++ b/relay/adaptor/aws/llama3/model.go @@ -4,10 +4,10 @@ package aws // // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html type Request struct { - Prompt string `json:"prompt"` - MaxGenLen int `json:"max_gen_len,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Prompt string `json:"prompt"` + MaxGenLen int `json:"max_gen_len,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` } // Response is the response from AWS Llama3 diff --git a/relay/adaptor/aws/utils/adaptor.go b/relay/adaptor/aws/utils/adaptor.go index 4cb880f2..f5fc0038 100644 --- a/relay/adaptor/aws/utils/adaptor.go +++ b/relay/adaptor/aws/utils/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go index 15306b95..6587b575 100644 --- a/relay/adaptor/baidu/adaptor.go +++ b/relay/adaptor/baidu/adaptor.go @@ -3,15 +3,15 @@ package baidu import ( "errors" "fmt" - "github.com/songquanpeng/one-api/relay/meta" - "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" ) type Adaptor struct { @@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) 
ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index ebe70c32..d96a2377 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -16,6 +15,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" @@ -35,9 +35,9 @@ type Message struct { type ChatRequest struct { Messages []Message `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - PenaltyScore float64 `json:"penalty_score,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + PenaltyScore *float64 `json:"penalty_score,omitempty"` Stream bool `json:"stream,omitempty"` System string `json:"system,omitempty"` DisableSearch bool `json:"disable_search,omitempty"` diff --git a/relay/adaptor/baidu/model.go b/relay/adaptor/baidu/model.go index cc1feb2f..8c55c18d 100644 --- a/relay/adaptor/baidu/model.go +++ b/relay/adaptor/baidu/model.go @@ -1,8 +1,9 @@ package baidu import ( - "github.com/songquanpeng/one-api/relay/model" "time" + + "github.com/songquanpeng/one-api/relay/model" ) type ChatResponse struct { diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go index 97e3dbb2..8958466d 100644 --- a/relay/adaptor/cloudflare/adaptor.go +++ b/relay/adaptor/cloudflare/adaptor.go @@ -19,7 +19,7 @@ type Adaptor struct { } // ConvertImageRequest implements adaptor.Adaptor. -func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { return nil, errors.New("not implemented") } diff --git a/relay/adaptor/cloudflare/model.go b/relay/adaptor/cloudflare/model.go index 0d3bafe0..8e382ba7 100644 --- a/relay/adaptor/cloudflare/model.go +++ b/relay/adaptor/cloudflare/model.go @@ -9,5 +9,5 @@ type Request struct { Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` } diff --git a/relay/adaptor/cohere/adaptor.go b/relay/adaptor/cohere/adaptor.go index 6fdb1b04..dd90bd7b 100644 --- a/relay/adaptor/cohere/adaptor.go +++ b/relay/adaptor/cohere/adaptor.go @@ -15,7 +15,7 @@ import ( type Adaptor struct{} // ConvertImageRequest implements adaptor.Adaptor. 
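Several of the adaptors above (Anthropic, AWS Claude and Llama 3, Baidu, Cloudflare, Cohere) change sampling fields such as Temperature and TopP from float64 to *float64. A minimal sketch of the marshaling difference this buys, using hypothetical struct names rather than the adaptors' own types: with a plain float64 under omitempty, a caller's explicit 0 is silently dropped, while a pointer distinguishes "not set" (nil) from "explicitly zero".

```go
package main

import (
	"encoding/json"
	"fmt"
)

// valueParams mirrors the old style: a zero temperature is
// indistinguishable from "unset" and is dropped by omitempty.
type valueParams struct {
	Temperature float64 `json:"temperature,omitempty"`
}

// pointerParams mirrors the new style: nil means "unset",
// while an explicit 0 survives marshaling.
type pointerParams struct {
	Temperature *float64 `json:"temperature,omitempty"`
}

func main() {
	zero := 0.0
	oldJSON, _ := json.Marshal(valueParams{Temperature: 0})
	newJSON, _ := json.Marshal(pointerParams{Temperature: &zero})
	unsetJSON, _ := json.Marshal(pointerParams{Temperature: nil})
	fmt.Println(string(oldJSON))   // {} (the caller's explicit 0 is lost)
	fmt.Println(string(newJSON))   // {"temperature":0}
	fmt.Println(string(unsetJSON)) // {}
}
```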
-func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { return nil, errors.New("not implemented") } diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 45db437b..736c5a8d 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { K: textRequest.TopK, Stream: textRequest.Stream, FrequencyPenalty: textRequest.FrequencyPenalty, - PresencePenalty: textRequest.FrequencyPenalty, + PresencePenalty: textRequest.PresencePenalty, Seed: int(textRequest.Seed), } if cohereRequest.Model == "" { diff --git a/relay/adaptor/cohere/model.go b/relay/adaptor/cohere/model.go index 64fa9c94..3a8bc99d 100644 --- a/relay/adaptor/cohere/model.go +++ b/relay/adaptor/cohere/model.go @@ -10,15 +10,15 @@ type Request struct { PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" Connectors []Connector `json:"connectors,omitempty"` Documents []Document `json:"documents,omitempty"` - Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3 + Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3 MaxTokens int `json:"max_tokens,omitempty"` MaxInputTokens int `json:"max_input_tokens,omitempty"` K int `json:"k,omitempty"` // 默认值为0 - P float64 `json:"p,omitempty"` // 默认值为0.75 + P *float64 `json:"p,omitempty"` // 默认值为0.75 Seed int `json:"seed,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 - PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 Tools []Tool `json:"tools,omitempty"` ToolResults []ToolResult `json:"tool_results,omitempty"` } diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go index 8953d7a3..302e661a 100644 --- a/relay/adaptor/common.go +++ b/relay/adaptor/common.go @@ -1,13 +1,14 @@ package adaptor import ( - "errors" - "fmt" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/client" - "github.com/songquanpeng/one-api/relay/meta" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/meta" ) func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { @@ -21,19 +22,24 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(meta) if err != nil { - return nil, fmt.Errorf("get request url failed: %w", err) + return nil, errors.Wrap(err, "get request url failed") } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + + req, err := http.NewRequestWithContext(c.Request.Context(), + c.Request.Method, fullRequestURL, requestBody) if err != nil { - return nil, fmt.Errorf("new request failed: %w", err) + return nil, errors.Wrap(err, "new request failed") } + + req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType)) + err = a.SetupRequestHeader(c, req, meta) if err != nil { - return nil, fmt.Errorf("setup request header failed: %w", err) + return nil, 
errors.Wrap(err, "setup request header failed") } resp, err := DoRequest(c, req) if err != nil { - return nil, fmt.Errorf("do request failed: %w", err) + return nil, errors.Wrap(err, "do request failed") } return resp, nil } @@ -48,5 +54,6 @@ func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { } _ = req.Body.Close() _ = c.Request.Body.Close() + return resp, nil } diff --git a/relay/adaptor/coze/adaptor.go b/relay/adaptor/coze/adaptor.go index 44f560e8..30f43353 100644 --- a/relay/adaptor/coze/adaptor.go +++ b/relay/adaptor/coze/adaptor.go @@ -3,13 +3,14 @@ package coze import ( "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) type Adaptor struct { @@ -38,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/deepl/adaptor.go b/relay/adaptor/deepl/adaptor.go index d018a096..5a03c261 100644 --- a/relay/adaptor/deepl/adaptor.go +++ b/relay/adaptor/deepl/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return convertedRequest, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index 12f48c71..33733de3 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -1,12 +1,12 @@ package gemini import ( - "errors" "fmt" "io" "net/http" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" channelhelper "github.com/songquanpeng/one-api/relay/adaptor" @@ -24,7 +24,16 @@ func (a *Adaptor) Init(meta *meta.Meta) { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) + var defaultVersion string + switch meta.ActualModelName { + case "gemini-2.0-flash-exp", + "gemini-2.0-flash-thinking-exp": + defaultVersion = "v1beta" + default: + defaultVersion = config.GeminiVersion + } + + version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) action := "" switch meta.Mode { case relaymode.Embeddings: @@ -36,12 +45,14 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { if meta.IsStream { action = "streamGenerateContent?alt=sse" } + return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channelhelper.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) + req.URL.Query().Add("key", meta.APIKey) return nil } @@ -59,7 +70,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request 
*model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go index b0f84dfc..9d1cbc4a 100644 --- a/relay/adaptor/gemini/constants.go +++ b/relay/adaptor/gemini/constants.go @@ -3,5 +3,9 @@ package gemini // https://ai.google.dev/models/gemini var ModelList = []string{ - "gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa", + "gemini-pro", "gemini-1.0-pro", + "gemini-1.5-flash", "gemini-1.5-pro", + "text-embedding-004", "aqa", + "gemini-2.0-flash-exp", + "gemini-2.0-flash-thinking-exp", } diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 51fd6aa8..28960706 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,22 +4,22 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - - "github.com/gin-gonic/gin" ) // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn @@ -28,6 +28,17 @@ const ( VisionMaxImageNum = 16 ) +var mimeTypeMap = map[string]string{ + "json_object": "application/json", + "text": "text/plain", +} + +var toolChoiceTypeMap = map[string]string{ + "none": "NONE", + "auto": "AUTO", + "required": "ANY", +} + // Setting safety to the lowest possible values since Gemini is already powerless enough func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { geminiRequest := ChatRequest{ @@ -49,6 +60,10 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Threshold: config.GeminiSafetySetting, }, + { + Category: "HARM_CATEGORY_CIVIC_INTEGRITY", + Threshold: config.GeminiSafetySetting, + }, }, GenerationConfig: ChatGenerationConfig{ Temperature: textRequest.Temperature, @@ -56,6 +71,15 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { MaxOutputTokens: textRequest.MaxTokens, }, } + if textRequest.ResponseFormat != nil { + if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok { + geminiRequest.GenerationConfig.ResponseMimeType = mimeType + } + if textRequest.ResponseFormat.JsonSchema != nil { + geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema + geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"] + } + } if textRequest.Tools != nil { functions := make([]model.Function, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { @@ -73,7 +97,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { }, } } - shouldAddDummyModelMessage := false + if textRequest.ToolChoice != nil { + geminiRequest.ToolConfig = &ToolConfig{ + FunctionCallingConfig: FunctionCallingConfig{ + Mode: "auto", + }, + } + switch mode := textRequest.ToolChoice.(type) { + case string: + 
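+			// tool_choice arrived as a plain string: map OpenAI's "none"/"auto"/"required" onto Gemini's NONE/AUTO/ANY via toolChoiceTypeMap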
geminiRequest.ToolConfig.FunctionCallingConfig.Mode = toolChoiceTypeMap[mode] + case map[string]interface{}: + geminiRequest.ToolConfig.FunctionCallingConfig.Mode = "ANY" + if fn, ok := mode["function"].(map[string]interface{}); ok { + if name, ok := fn["name"].(string); ok { + geminiRequest.ToolConfig.FunctionCallingConfig.AllowedFunctionNames = []string{name} + } + } + } + } for _, message := range textRequest.Messages { content := ChatContent{ Role: message.Role, @@ -111,25 +152,12 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { if content.Role == "assistant" { content.Role = "model" } - // Converting system prompt to prompt from user for the same reason + // Converting system prompt to SystemInstructions if content.Role == "system" { - content.Role = "user" - shouldAddDummyModelMessage = true + geminiRequest.SystemInstruction = &content + continue } geminiRequest.Contents = append(geminiRequest.Contents, content) - - // If a system message is the last message, we need to add a dummy model message to make gemini happy - if shouldAddDummyModelMessage { - geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{ - Role: "model", - Parts: []Part{ - { - Text: "Okay", - }, - }, - }) - shouldAddDummyModelMessage = false - } } return &geminiRequest @@ -167,10 +195,16 @@ func (g *ChatResponse) GetResponseText() string { if g == nil { return "" } - if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { - return g.Candidates[0].Content.Parts[0].Text + var builder strings.Builder + for _, candidate := range g.Candidates { + for idx, part := range candidate.Content.Parts { + if idx > 0 { + builder.WriteString("\n") + } + builder.WriteString(part.Text) + } } - return "" + return builder.String() } type ChatCandidate struct { @@ -232,7 +266,14 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { if candidate.Content.Parts[0].FunctionCall != nil { choice.Message.ToolCalls = getToolCalls(&candidate) } else { - choice.Message.Content = candidate.Content.Parts[0].Text + var builder strings.Builder + for idx, part := range candidate.Content.Parts { + if idx > 0 { + builder.WriteString("\n") + } + builder.WriteString(part.Text) + } + choice.Message.Content = builder.String() } } else { choice.Message.Content = "" diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go index f7179ea4..a19248bc 100644 --- a/relay/adaptor/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -1,10 +1,12 @@ package gemini type ChatRequest struct { - Contents []ChatContent `json:"contents"` - SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"` - GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"` - Tools []ChatTools `json:"tools,omitempty"` + Contents []ChatContent `json:"contents"` + SystemInstruction *ChatContent `json:"system_instruction,omitempty"` + SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"` + Tools []ChatTools `json:"tools,omitempty"` + ToolConfig *ToolConfig `json:"tool_config,omitempty"` } type EmbeddingRequest struct { @@ -65,10 +67,21 @@ type ChatTools struct { } type ChatGenerationConfig struct { - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - StopSequences []string 
`json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` + ResponseSchema any `json:"responseSchema,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +type FunctionCallingConfig struct { + Mode string `json:"mode,omitempty"` + AllowedFunctionNames []string `json:"allowed_function_names,omitempty"` +} + +type ToolConfig struct { + FunctionCallingConfig FunctionCallingConfig `json:"function_calling_config"` } diff --git a/relay/adaptor/groq/constants.go b/relay/adaptor/groq/constants.go index 559851ee..0864ebe7 100644 --- a/relay/adaptor/groq/constants.go +++ b/relay/adaptor/groq/constants.go @@ -4,14 +4,24 @@ package groq var ModelList = []string{ "gemma-7b-it", - "mixtral-8x7b-32768", - "llama3-8b-8192", - "llama3-70b-8192", "gemma2-9b-it", - "llama-3.1-405b-reasoning", "llama-3.1-70b-versatile", "llama-3.1-8b-instant", + "llama-3.2-11b-text-preview", + "llama-3.2-11b-vision-preview", + "llama-3.2-1b-preview", + "llama-3.2-3b-preview", + "llama-3.2-11b-vision-preview", + "llama-3.2-90b-text-preview", + "llama-3.2-90b-vision-preview", + "llama-guard-3-8b", + "llama3-70b-8192", + "llama3-8b-8192", "llama3-groq-70b-8192-tool-use-preview", "llama3-groq-8b-8192-tool-use-preview", + "llava-v1.5-7b-4096-preview", + "mixtral-8x7b-32768", + "distil-whisper-large-v3-en", "whisper-large-v3", + "whisper-large-v3-turbo", } diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go index 01b2e2cb..88667561 100644 --- a/relay/adaptor/interface.go +++ b/relay/adaptor/interface.go @@ -13,7 +13,7 @@ type Adaptor interface { GetRequestURL(meta *meta.Meta) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) - ConvertImageRequest(request *model.ImageRequest) (any, error) + ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) GetModelList() []string diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go index ad1f8983..35cc2a61 100644 --- a/relay/adaptor/ollama/adaptor.go +++ b/relay/adaptor/ollama/adaptor.go @@ -1,16 +1,16 @@ package ollama import ( - "errors" "fmt" - "github.com/songquanpeng/one-api/relay/meta" - "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" ) type Adaptor struct { @@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go index 43317ff6..2462fc3e 100644 --- 
a/relay/adaptor/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -5,18 +5,17 @@ import ( "context" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" - "github.com/songquanpeng/one-api/common/helper" - "github.com/songquanpeng/one-api/common/random" - "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" @@ -31,8 +30,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { TopP: request.TopP, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, - NumPredict: request.MaxTokens, - NumCtx: request.NumCtx, + NumPredict: request.MaxTokens, + NumCtx: request.NumCtx, }, Stream: request.Stream, } @@ -122,7 +121,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC for scanner.Scan() { data := scanner.Text() if strings.HasPrefix(data, "}") { - data = strings.TrimPrefix(data, "}") + "}" + data = strings.TrimPrefix(data, "}") + "}" } var ollamaResponse ChatResponse diff --git a/relay/adaptor/ollama/model.go b/relay/adaptor/ollama/model.go index 7039984f..94f2ab73 100644 --- a/relay/adaptor/ollama/model.go +++ b/relay/adaptor/ollama/model.go @@ -1,14 +1,14 @@ package ollama type Options struct { - Seed int `json:"seed,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` + Seed int `json:"seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } type Message struct { diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 5dc395ad..5612105c 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -1,13 +1,14 @@ package openai import ( - "errors" "fmt" "io" "net/http" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/doubao" "github.com/songquanpeng/one-api/relay/adaptor/minimax" @@ -75,10 +76,39 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } + + if config.EnforceIncludeUsage && request.Stream { + // always return usage in stream mode + if request.StreamOptions == nil { + request.StreamOptions = &model.StreamOptions{} + } + request.StreamOptions.IncludeUsage = true + } + + // o1/o1-mini/o1-preview do not support system prompt and max_tokens + if strings.HasPrefix(request.Model, "o1") { + request.MaxTokens = 0 + request.Messages = func(raw []model.Message) 
(filtered []model.Message) { + for i := range raw { + if raw[i].Role != "system" { + filtered = append(filtered, raw[i]) + } + } + + return + }(request.Messages) + } + + if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") { + // TODO: Since it is not clear how to implement billing in stream mode, + // it is temporarily not supported + return nil, errors.New("stream mode is not supported for gpt-4o-audio") + } + return request, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } @@ -104,10 +134,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met switch meta.Mode { case relaymode.ImagesGenerations: err, _ = ImageHandler(c, resp) + case relaymode.ImagesEdits: + err, _ = ImagesEditsHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } } + return } diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go index 0512f05c..15b4dcc0 100644 --- a/relay/adaptor/openai/compatible.go +++ b/relay/adaptor/openai/compatible.go @@ -11,9 +11,10 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/moonshot" "github.com/songquanpeng/one-api/relay/adaptor/novita" + "github.com/songquanpeng/one-api/relay/adaptor/siliconflow" "github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/togetherai" - "github.com/songquanpeng/one-api/relay/adaptor/siliconflow" + "github.com/songquanpeng/one-api/relay/adaptor/xai" "github.com/songquanpeng/one-api/relay/channeltype" ) @@ -32,6 +33,7 @@ var CompatibleChannels = []int{ channeltype.TogetherAI, channeltype.Novita, channeltype.SiliconFlow, + channeltype.XAI, } func GetCompatibleChannelMeta(channelType int) (string, []string) { @@ -64,6 +66,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "novita", novita.ModelList case channeltype.SiliconFlow: return "siliconflow", siliconflow.ModelList + case channeltype.XAI: + return "xai", xai.ModelList default: return "openai", ModelList } diff --git a/relay/adaptor/openai/constants.go b/relay/adaptor/openai/constants.go index 156a50e7..7f116e28 100644 --- a/relay/adaptor/openai/constants.go +++ b/relay/adaptor/openai/constants.go @@ -7,8 +7,9 @@ var ModelList = []string{ "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", - "gpt-4o", "gpt-4o-2024-05-13", + "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20", "chatgpt-4o-latest", "gpt-4o-mini", "gpt-4o-mini-2024-07-18", + "gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-12-17", "gpt-4o-audio-preview-2024-10-01", "gpt-4-vision-preview", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", @@ -18,4 +19,7 @@ var ModelList = []string{ "dall-e-2", "dall-e-3", "whisper-1", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", + "o1", "o1-2024-12-17", + "o1-preview", "o1-preview-2024-09-12", + "o1-mini", "o1-mini-2024-09-12", } diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go index 7d73303b..47c2a882 100644 --- 
a/relay/adaptor/openai/helper.go +++ b/relay/adaptor/openai/helper.go @@ -2,15 +2,16 @@ package openai import ( "fmt" + "strings" + "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/model" - "strings" ) -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { +func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { usage := &model.Usage{} usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) + usage.CompletionTokens = CountTokenText(responseText, modelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage } diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go index 0f89618a..433d9421 100644 --- a/relay/adaptor/openai/image.go +++ b/relay/adaptor/openai/image.go @@ -3,12 +3,30 @@ package openai import ( "bytes" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/model" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" ) +// ImagesEditsHandler just copy response body to client +// +// https://platform.openai.com/docs/api-reference/images/createEdit +func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + c.Writer.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + defer resp.Body.Close() + + return nil, nil +} + func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var imageResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 9ee547b3..f986ed09 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -5,15 +5,16 @@ import ( "bytes" "encoding/json" "io" + "math" "net/http" "strings" - "github.com/songquanpeng/one-api/common/render" - "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" ) @@ -55,8 +56,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E render.StringData(c, data) // if error happened, pass the data to client continue // just ignore the error } - if len(streamResponse.Choices) == 0 { - // but for empty choice, we should not pass it to client, this is for azure + if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil { + // but for empty choice and no usage, we should not pass it to client, this is for azure continue // just ignore empty choice } render.StringData(c, data) @@ -96,6 +97,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E return nil, responseText, usage } +// Handler handles the non-stream response from OpenAI API func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { var textResponse SlimTextResponse responseBody, err := io.ReadAll(resp.Body) @@ -116,8 +118,10 @@ func Handler(c *gin.Context, resp 
*http.Response, promptTokens int, modelName st StatusCode: resp.StatusCode, }, nil } + // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody)) // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. @@ -146,6 +150,24 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, } + } else if textResponse.PromptTokensDetails.AudioTokens+textResponse.CompletionTokensDetails.AudioTokens > 0 { + // Convert the more expensive audio tokens to uniformly priced text tokens. + // Note that when there are no audio tokens in prompt and completion, + // OpenAI will return empty PromptTokensDetails and CompletionTokensDetails, which can be misleading. + textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens + + int(math.Ceil( + float64(textResponse.PromptTokensDetails.AudioTokens)* + ratio.GetAudioPromptRatio(modelName), + )) + textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens + + int(math.Ceil( + float64(textResponse.CompletionTokensDetails.AudioTokens)* + ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName), + )) + + textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens + + textResponse.Usage.CompletionTokens } + return nil, &textResponse.Usage } diff --git a/relay/adaptor/openai/model.go b/relay/adaptor/openai/model.go index 4c974de4..39e87262 100644 --- a/relay/adaptor/openai/model.go +++ b/relay/adaptor/openai/model.go @@ -1,6 +1,10 @@ package openai -import "github.com/songquanpeng/one-api/relay/model" +import ( + "mime/multipart" + + "github.com/songquanpeng/one-api/relay/model" +) type TextContent struct { Type string `json:"type,omitempty"` @@ -71,6 +75,24 @@ type TextToSpeechRequest struct { ResponseFormat string `json:"response_format"` } +type AudioTranscriptionRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` + Model string `form:"model" binding:"required"` + Language string `form:"language"` + Prompt string `form:"prompt"` + ReponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"` + Temperature float64 `form:"temperature"` + TimestampGranularity []string `form:"timestamp_granularity"` +} + +type AudioTranslationRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` + Model string `form:"model" binding:"required"` + Prompt string `form:"prompt"` + ResponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"` + Temperature float64 `form:"temperature"` +} + type UsageOrResponseText struct { *model.Usage ResponseText string diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go index 7c8468b9..53ac5297 100644 --- a/relay/adaptor/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -1,16 +1,22 @@ package openai import ( - "errors" + "bytes" + "context" + "encoding/base64" "fmt" - "github.com/pkoukk/tiktoken-go" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/image" - "github.com/songquanpeng/one-api/common/logger" - billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" - "github.com/songquanpeng/one-api/relay/model" "math" "strings" + + 
"github.com/pkg/errors" + "github.com/pkoukk/tiktoken-go" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/image" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/billing/ratio" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/model" ) // tokenEncoderMap won't grow after initialization @@ -70,8 +76,9 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func CountTokenMessages(messages []model.Message, model string) int { - tokenEncoder := getTokenEncoder(model) +func CountTokenMessages(ctx context.Context, + messages []model.Message, actualModel string) int { + tokenEncoder := getTokenEncoder(actualModel) // Reference: // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb // https://github.com/pkoukk/tiktoken-go/issues/6 @@ -79,7 +86,7 @@ func CountTokenMessages(messages []model.Message, model string) int { // Every message follows <|start|>{role/name}\n{content}<|end|>\n var tokensPerMessage int var tokensPerName int - if model == "gpt-3.5-turbo-0301" { + if actualModel == "gpt-3.5-turbo-0301" { tokensPerMessage = 4 tokensPerName = -1 // If there's a name, the role is omitted } else { @@ -89,37 +96,38 @@ func CountTokenMessages(messages []model.Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - switch v := message.Content.(type) { - case string: - tokenNum += getTokenNum(tokenEncoder, v) - case []any: - for _, it := range v { - m := it.(map[string]any) - switch m["type"] { - case "text": - if textValue, ok := m["text"]; ok { - if textString, ok := textValue.(string); ok { - tokenNum += getTokenNum(tokenEncoder, textString) - } - } - case "image_url": - imageUrl, ok := m["image_url"].(map[string]any) - if ok { - url := imageUrl["url"].(string) - detail := "" - if imageUrl["detail"] != nil { - detail = imageUrl["detail"].(string) - } - imageTokens, err := countImageTokens(url, detail, model) - if err != nil { - logger.SysError("error counting image tokens: " + err.Error()) - } else { - tokenNum += imageTokens - } - } + contents := message.ParseContent() + for _, content := range contents { + switch content.Type { + case model.ContentTypeText: + tokenNum += getTokenNum(tokenEncoder, content.Text) + case model.ContentTypeImageURL: + imageTokens, err := countImageTokens( + content.ImageURL.Url, + content.ImageURL.Detail, + actualModel) + if err != nil { + logger.SysError("error counting image tokens: " + err.Error()) + } else { + tokenNum += imageTokens + } + case model.ContentTypeInputAudio: + audioData, err := base64.StdEncoding.DecodeString(content.InputAudio.Data) + if err != nil { + logger.SysError("error decoding audio data: " + err.Error()) + } + + tokens, err := helper.GetAudioTokens(ctx, + bytes.NewReader(audioData), + ratio.GetAudioPromptTokensPerSecond(actualModel)) + if err != nil { + logger.SysError("error counting audio tokens: " + err.Error()) + } else { + tokenNum += tokens } } } + tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName @@ -130,6 +138,53 @@ func CountTokenMessages(messages []model.Message, model string) int { return tokenNum } +// func countVisonTokenMessages(messages []VisionMessage, model string) (int, error) { +// tokenEncoder := 
getTokenEncoder(model) +// // Reference: +// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb +// // https://github.com/pkoukk/tiktoken-go/issues/6 +// // +// // Every message follows <|start|>{role/name}\n{content}<|end|>\n +// var tokensPerMessage int +// var tokensPerName int +// if model == "gpt-3.5-turbo-0301" { +// tokensPerMessage = 4 +// tokensPerName = -1 // If there's a name, the role is omitted +// } else { +// tokensPerMessage = 3 +// tokensPerName = 1 +// } +// tokenNum := 0 +// for _, message := range messages { +// tokenNum += tokensPerMessage +// for _, cnt := range message.Content { +// switch cnt.Type { +// case OpenaiVisionMessageContentTypeText: +// tokenNum += getTokenNum(tokenEncoder, cnt.Text) +// case OpenaiVisionMessageContentTypeImageUrl: +// imgblob, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(cnt.ImageUrl.URL, "data:image/jpeg;base64,")) +// if err != nil { +// return 0, errors.Wrap(err, "failed to decode base64 image") +// } + +// if imgtoken, err := CountVisionImageToken(imgblob, cnt.ImageUrl.Detail); err != nil { +// return 0, errors.Wrap(err, "failed to count vision image token") +// } else { +// tokenNum += imgtoken +// } +// } +// } + +// tokenNum += getTokenNum(tokenEncoder, message.Role) +// if message.Name != nil { +// tokenNum += tokensPerName +// tokenNum += getTokenNum(tokenEncoder, *message.Name) +// } +// } +// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> +// return tokenNum, nil +// } + const ( lowDetailCost = 85 highDetailCostPerTile = 170 diff --git a/relay/adaptor/openai/util.go b/relay/adaptor/openai/util.go index ba0cab7d..83beadba 100644 --- a/relay/adaptor/openai/util.go +++ b/relay/adaptor/openai/util.go @@ -1,8 +1,16 @@ package openai -import "github.com/songquanpeng/one-api/relay/model" +import ( + "context" + "fmt" + + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" +) func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { + logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) + Error := model.Error{ Message: err.Error(), Type: "one_api_error", diff --git a/relay/adaptor/palm/adaptor.go b/relay/adaptor/palm/adaptor.go index 98aa3e18..65e3d3c5 100644 --- a/relay/adaptor/palm/adaptor.go +++ b/relay/adaptor/palm/adaptor.go @@ -1,9 +1,9 @@ package palm import ( - "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" @@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/palm/model.go b/relay/adaptor/palm/model.go index f653022c..2bdd8f29 100644 --- a/relay/adaptor/palm/model.go +++ b/relay/adaptor/palm/model.go @@ -19,11 +19,11 @@ type Prompt struct { } type ChatRequest struct { - Prompt Prompt `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - CandidateCount int `json:"candidateCount,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` + Prompt Prompt `json:"prompt"` + Temperature *float64 
`json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` } type Error struct { diff --git a/relay/adaptor/proxy/adaptor.go b/relay/adaptor/proxy/adaptor.go index 670c7628..32984fc7 100644 --- a/relay/adaptor/proxy/adaptor.go +++ b/relay/adaptor/proxy/adaptor.go @@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { return nil, errors.Errorf("not implement") } diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go new file mode 100644 index 00000000..296a9aaa --- /dev/null +++ b/relay/adaptor/replicate/adaptor.go @@ -0,0 +1,180 @@ +package replicate + +import ( + "bytes" + "fmt" + "io" + "net/http" + "slices" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { + return nil, errors.New("should call replicate.ConvertImageRequest instead") +} + +func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") + } + if request.N != 1 && request.N != 0 { + return nil, errors.New("only support N=1") + } + + switch meta.Mode { + case relaymode.ImagesGenerations: + return convertImageCreateRequest(request) + case relaymode.ImagesEdits: + return convertImageRemixRequest(c) + default: + return nil, errors.New("not implemented") + } +} + +func convertImageCreateRequest(request *model.ImageRequest) (any, error) { + return DrawImageRequest{ + Input: ImageInput{ + Steps: 25, + Prompt: request.Prompt, + Guidance: 3, + Seed: int(time.Now().UnixNano()), + SafetyTolerance: 5, + NImages: 1, // replicate will always return 1 image + Width: 1440, + Height: 1440, + AspectRatio: "1:1", + }, + }, nil +} + +func convertImageRemixRequest(c *gin.Context) (any, error) { + // recover request body + requestBody, err := common.GetRequestBody(c) + if err != nil { + return nil, errors.Wrap(err, "get request body") + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + rawReq := new(OpenaiImageEditRequest) + if err := c.ShouldBind(rawReq); err != nil { + return nil, errors.Wrap(err, "parse image edit form") + } + + return rawReq.toFluxRemixRequest() +} + +// ConvertRequest converts the request to the format that the target API expects. 
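The Replicate chat input built below takes a single prompt string rather than role-tagged messages, so the ConvertRequest that follows flattens the conversation into "role: content" lines and keeps only plain-string message contents. A small, self-contained sketch of that flattening; the helper name and the stripped-down message type are hypothetical, not part of the adaptor:

```go
package main

import (
	"fmt"
	"strings"
)

// chatMessage is a stripped-down stand-in for the relay's model.Message.
type chatMessage struct {
	Role    string
	Content any // only plain string contents are flattened; structured parts are skipped
}

// flattenToPrompt mirrors the scheme used by the Replicate adaptor's
// ConvertRequest: each string message becomes "role: content\n".
func flattenToPrompt(messages []chatMessage) string {
	var b strings.Builder
	for _, m := range messages {
		if text, ok := m.Content.(string); ok {
			b.WriteString(m.Role)
			b.WriteString(": ")
			b.WriteString(text)
			b.WriteString("\n")
		}
	}
	return b.String()
}

func main() {
	prompt := flattenToPrompt([]chatMessage{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello!"},
	})
	fmt.Print(prompt)
	// system: You are a helpful assistant.
	// user: Hello!
}
```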
+func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if !request.Stream { + // TODO: support non-stream mode + return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true") + } + + // Build the prompt from OpenAI messages + var promptBuilder strings.Builder + for _, message := range request.Messages { + switch msgCnt := message.Content.(type) { + case string: + promptBuilder.WriteString(message.Role) + promptBuilder.WriteString(": ") + promptBuilder.WriteString(msgCnt) + promptBuilder.WriteString("\n") + default: + } + } + + replicateRequest := ReplicateChatRequest{ + Input: ChatInput{ + Prompt: promptBuilder.String(), + MaxTokens: request.MaxTokens, + Temperature: 1.0, + TopP: 1.0, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + }, + } + + // Map optional fields + if request.Temperature != nil { + replicateRequest.Input.Temperature = *request.Temperature + } + if request.TopP != nil { + replicateRequest.Input.TopP = *request.TopP + } + if request.PresencePenalty != nil { + replicateRequest.Input.PresencePenalty = *request.PresencePenalty + } + if request.FrequencyPenalty != nil { + replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty + } + if request.MaxTokens > 0 { + replicateRequest.Input.MaxTokens = request.MaxTokens + } else if request.MaxTokens == 0 { + replicateRequest.Input.MaxTokens = 500 + } + + return replicateRequest, nil +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + if !slices.Contains(ModelList, meta.OriginModelName) { + return "", errors.Errorf("model %s not supported", meta.OriginModelName) + } + + return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + logger.Info(c, "send request to replicate") + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: + err, usage = ImageHandler(c, resp) + case relaymode.ChatCompletions: + err, usage = ChatHandler(c, resp) + default: + err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) + } + + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "replicate" +} diff --git a/relay/adaptor/replicate/chat.go b/relay/adaptor/replicate/chat.go new file mode 100644 index 00000000..4051f85c --- /dev/null +++ b/relay/adaptor/replicate/chat.go @@ -0,0 +1,191 @@ +package replicate + +import ( + "bufio" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/render" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +func ChatHandler(c *gin.Context, resp 
*http.Response) ( + srvErr *model.ErrorWithStatusCode, usage *model.Usage) { + if resp.StatusCode != http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ChatResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ChatResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + if taskData.URLs.Stream == "" { + return errors.New("stream url is empty") + } + + // request stream url + responseText, err := chatStreamHandler(c, taskData.URLs.Stream) + if err != nil { + return errors.Wrap(err, "chat stream handler") + } + + ctxMeta := meta.GetByContext(c) + usage = openai.ResponseText2Usage(responseText, + ctxMeta.ActualModelName, ctxMeta.PromptTokens) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, usage +} + +const ( + eventPrefix = "event: " + dataPrefix = "data: " + done = "[DONE]" +) + +func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) { + // request stream endpoint + streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil) + if err != nil { + return "", errors.Wrap(err, "new request to stream") + } + + streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + streamReq.Header.Set("Accept", "text/event-stream") + streamReq.Header.Set("Cache-Control", "no-store") + + resp, err := http.DefaultClient.Do(streamReq) + if err != nil { + return "", errors.Wrap(err, "do request to stream") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(resp.Body) + return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + common.SetEventStreamHeaders(c) + doneRendered := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // Handle 
comments starting with ':' + if strings.HasPrefix(line, ":") { + continue + } + + // Parse SSE fields + if strings.HasPrefix(line, eventPrefix) { + event := strings.TrimSpace(line[len(eventPrefix):]) + var data string + // Read the following lines to get data and id + for scanner.Scan() { + nextLine := scanner.Text() + if nextLine == "" { + break + } + if strings.HasPrefix(nextLine, dataPrefix) { + data = nextLine[len(dataPrefix):] + } else if strings.HasPrefix(nextLine, "id:") { + // id = strings.TrimSpace(nextLine[len("id:"):]) + } + } + + if event == "output" { + render.StringData(c, data) + responseText += data + } else if event == "done" { + render.Done(c) + doneRendered = true + break + } + } + } + + if err := scanner.Err(); err != nil { + return "", errors.Wrap(err, "scan stream") + } + + if !doneRendered { + render.Done(c) + } + + return responseText, nil +} diff --git a/relay/adaptor/replicate/constant.go b/relay/adaptor/replicate/constant.go new file mode 100644 index 00000000..989142c9 --- /dev/null +++ b/relay/adaptor/replicate/constant.go @@ -0,0 +1,58 @@ +package replicate + +// ModelList is a list of models that can be used with Replicate. +// +// https://replicate.com/pricing +var ModelList = []string{ + // ------------------------------------- + // image model + // ------------------------------------- + "black-forest-labs/flux-1.1-pro", + "black-forest-labs/flux-1.1-pro-ultra", + "black-forest-labs/flux-canny-dev", + "black-forest-labs/flux-canny-pro", + "black-forest-labs/flux-depth-dev", + "black-forest-labs/flux-depth-pro", + "black-forest-labs/flux-dev", + "black-forest-labs/flux-dev-lora", + "black-forest-labs/flux-fill-dev", + "black-forest-labs/flux-fill-pro", + "black-forest-labs/flux-pro", + "black-forest-labs/flux-redux-dev", + "black-forest-labs/flux-redux-schnell", + "black-forest-labs/flux-schnell", + "black-forest-labs/flux-schnell-lora", + "ideogram-ai/ideogram-v2", + "ideogram-ai/ideogram-v2-turbo", + "recraft-ai/recraft-v3", + "recraft-ai/recraft-v3-svg", + "stability-ai/stable-diffusion-3", + "stability-ai/stable-diffusion-3.5-large", + "stability-ai/stable-diffusion-3.5-large-turbo", + "stability-ai/stable-diffusion-3.5-medium", + // ------------------------------------- + // language model + // ------------------------------------- + "ibm-granite/granite-20b-code-instruct-8k", + "ibm-granite/granite-3.0-2b-instruct", + "ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k", + "meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + "meta/meta-llama-3.1-405b-instruct", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct", + "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1", + "mistralai/mixtral-8x7b-instruct-v0.1", + // ------------------------------------- + // video model + // ------------------------------------- + // "minimax/video-01", // TODO: implement the adaptor +} diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go new file mode 100644 index 00000000..d4f7f749 --- /dev/null +++ b/relay/adaptor/replicate/image.go @@ -0,0 +1,207 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "image" + "image/png" + "io" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + 
"github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "golang.org/x/image/webp" + "golang.org/x/sync/errgroup" +) + +var errNextLoop = errors.New("next_loop") + +// ImageHandler handles the response from the image creation or remix request +func ImageHandler(c *gin.Context, resp *http.Response) ( + *model.ErrorWithStatusCode, *model.Usage) { + if resp.StatusCode != http.StatusCreated { + payload, _ := io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ImageResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ImageResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + output, err := taskData.GetOutput() + if err != nil { + return errors.Wrap(err, "get output") + } + if len(output) == 0 { + return errors.New("response output is empty") + } + + var mu sync.Mutex + var pool errgroup.Group + respBody := &openai.ImageResponse{ + Created: taskData.CompletedAt.Unix(), + Data: []openai.ImageData{}, + } + + for _, imgOut := range output { + imgOut := imgOut + pool.Go(func() error { + // download image + downloadReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, imgOut, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + imgResp, err := http.DefaultClient.Do(downloadReq) + if err != nil { + return errors.Wrap(err, "download image") + } + defer imgResp.Body.Close() + + if imgResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(imgResp.Body) + return errors.Errorf("bad status code [%d]%s", + imgResp.StatusCode, string(payload)) + } + + imgData, err := io.ReadAll(imgResp.Body) + if err != nil { + return errors.Wrap(err, "read image") + } + + imgData, err = ConvertImageToPNG(imgData) + if err != nil { + return errors.Wrap(err, "convert image") + } + + mu.Lock() + respBody.Data = append(respBody.Data, openai.ImageData{ + B64Json: fmt.Sprintf("data:image/png;base64,%s", + base64.StdEncoding.EncodeToString(imgData)), + }) + mu.Unlock() + + return nil + }) + } + + if err := 
pool.Wait(); err != nil { + if len(respBody.Data) == 0 { + return errors.WithStack(err) + } + + logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) + } + + c.JSON(http.StatusOK, respBody) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, nil +} + +// ConvertImageToPNG converts a WebP image to PNG format +func ConvertImageToPNG(webpData []byte) ([]byte, error) { + // bypass if it's already a PNG image + if bytes.HasPrefix(webpData, []byte("\x89PNG")) { + return webpData, nil + } + + // check if is jpeg, convert to png + if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) { + img, _, err := image.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode jpeg") + } + + var pngBuffer bytes.Buffer + if err := png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil + } + + // Decode the WebP image + img, err := webp.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode webp") + } + + // Encode the image as PNG + var pngBuffer bytes.Buffer + if err := png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil +} diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go new file mode 100644 index 00000000..04fe5277 --- /dev/null +++ b/relay/adaptor/replicate/model.go @@ -0,0 +1,277 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "image" + "image/png" + "io" + "mime/multipart" + "time" + + "github.com/pkg/errors" +) + +type OpenaiImageEditRequest struct { + Image *multipart.FileHeader `json:"image" form:"image" binding:"required"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` + Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"` + Model string `json:"model" form:"model" binding:"required"` + N int `json:"n" form:"n" binding:"min=0,max=10"` + Size string `json:"size" form:"size"` + ResponseFormat string `json:"response_format" form:"response_format"` +} + +// toFluxRemixRequest convert OpenAI's image edit request to Flux's remix request. +// +// Note that the mask formats of OpenAI and Flux are different: +// OpenAI's mask sets the parts to be modified as transparent (0, 0, 0, 0), +// while Flux sets the parts to be modified as black (255, 255, 255, 255), +// so we need to convert the format here. +// +// Both OpenAI's Image and Mask are browser-native ImageData, +// which need to be converted to base64 dataURI format. 
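+//
+// A minimal illustration of the per-pixel rule implemented below (the
+// concrete RGBA values are illustrative, not taken from either API's
+// documentation):
+//
+//	// OpenAI mask pixel marked for editing: (R, G, B, A) = (0, 0, 0, 0)
+//	// is rewritten to the value Flux treats as "repaint here": (255, 255, 255, 255)
+//	// any pixel with A > 0 is copied through unchanged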
+func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) { + if r.ResponseFormat != "b64_json" { + return nil, errors.New("response_format must be b64_json for replicate models") + } + + fluxReq := &InpaintingImageByFlusReplicateRequest{ + Input: FluxInpaintingInput{ + Prompt: r.Prompt, + Seed: int(time.Now().UnixNano()), + Steps: 30, + Guidance: 3, + SafetyTolerance: 5, + PromptUnsampling: false, + OutputFormat: "png", + }, + } + + imgFile, err := r.Image.Open() + if err != nil { + return nil, errors.Wrap(err, "open image file") + } + defer imgFile.Close() + imgData, err := io.ReadAll(imgFile) + if err != nil { + return nil, errors.Wrap(err, "read image file") + } + + maskFile, err := r.Mask.Open() + if err != nil { + return nil, errors.Wrap(err, "open mask file") + } + defer maskFile.Close() + + // Convert image to base64 + imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData) + fluxReq.Input.Image = imageBase64 + + // Convert mask data to RGBA + maskPNG, err := png.Decode(maskFile) + if err != nil { + return nil, errors.Wrap(err, "decode mask file") + } + + // convert mask to RGBA + var maskRGBA *image.RGBA + switch converted := maskPNG.(type) { + case *image.RGBA: + maskRGBA = converted + default: + // Convert to RGBA + bounds := maskPNG.Bounds() + maskRGBA = image.NewRGBA(bounds) + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + maskRGBA.Set(x, y, maskPNG.At(x, y)) + } + } + } + + maskData := maskRGBA.Pix + invertedMask := make([]byte, len(maskData)) + for i := 0; i+4 <= len(maskData); i += 4 { + // If pixel is transparent (alpha = 0), make it black (255) + if maskData[i+3] == 0 { + invertedMask[i] = 255 // R + invertedMask[i+1] = 255 // G + invertedMask[i+2] = 255 // B + invertedMask[i+3] = 255 // A + } else { + // Copy original pixel + copy(invertedMask[i:i+4], maskData[i:i+4]) + } + } + + // Convert inverted mask to base64 encoded png image + invertedMaskRGBA := &image.RGBA{ + Pix: invertedMask, + Stride: maskRGBA.Stride, + Rect: maskRGBA.Rect, + } + + var buf bytes.Buffer + err = png.Encode(&buf, invertedMaskRGBA) + if err != nil { + return nil, errors.Wrap(err, "encode inverted mask to png") + } + + invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()) + fluxReq.Input.Mask = invertedMaskBase64 + + return fluxReq, nil +} + +// DrawImageRequest draw image by fluxpro +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type DrawImageRequest struct { + Input ImageInput `json:"input"` +} + +// ImageInput is input of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema +type ImageInput struct { + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + ImagePrompt string `json:"image_prompt"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + Interval int `json:"interval" binding:"required,min=1,max=4"` + AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + Seed int `json:"seed"` + NImages int `json:"n_images" binding:"required,min=1,max=8"` + Width int `json:"width" binding:"required,min=256,max=1440"` + Height int `json:"height" binding:"required,min=256,max=1440"` +} + +// InpaintingImageByFlusReplicateRequest is request to 
inpainting image by flux pro +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type InpaintingImageByFlusReplicateRequest struct { + Input FluxInpaintingInput `json:"input"` +} + +// FluxInpaintingInput is input of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type FluxInpaintingInput struct { + Mask string `json:"mask" binding:"required"` + Image string `json:"image" binding:"required"` + Seed int `json:"seed"` + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + OutputFormat string `json:"output_format"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + PromptUnsampling bool `json:"prompt_unsampling"` +} + +// ImageResponse is response of DrawImageByFluxProRequest +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type ImageResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input DrawImageRequest `json:"input"` + Logs string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output any `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs FluxURLs `json:"urls"` + Version string `json:"version"` +} + +func (r *ImageResponse) GetOutput() ([]string, error) { + switch v := r.Output.(type) { + case string: + return []string{v}, nil + case []string: + return v, nil + case nil: + return nil, nil + case []interface{}: + // convert []interface{} to []string + ret := make([]string, len(v)) + for idx, vv := range v { + if vvv, ok := vv.(string); ok { + ret[idx] = vvv + } else { + return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv) + } + } + + return ret, nil + default: + return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) + } +} + +// FluxMetrics is metrics of ImageResponse +type FluxMetrics struct { + ImageCount int `json:"image_count"` + PredictTime float64 `json:"predict_time"` + TotalTime float64 `json:"total_time"` +} + +// FluxURLs is urls of ImageResponse +type FluxURLs struct { + Get string `json:"get"` + Cancel string `json:"cancel"` +} + +type ReplicateChatRequest struct { + Input ChatInput `json:"input" form:"input" binding:"required"` +} + +// ChatInput is input of ChatByReplicateRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema +type ChatInput struct { + TopK int `json:"top_k"` + TopP float64 `json:"top_p"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` + MinTokens int `json:"min_tokens"` + Temperature float64 `json:"temperature"` + SystemPrompt string `json:"system_prompt"` + StopSequences string `json:"stop_sequences"` + PromptTemplate string `json:"prompt_template"` + PresencePenalty float64 `json:"presence_penalty"` + FrequencyPenalty float64 `json:"frequency_penalty"` +} + +// ChatResponse is response of ChatByReplicateRequest +// +// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json +type ChatResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input ChatInput `json:"input"` + Logs 
string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output []string `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs ChatResponseUrl `json:"urls"` + Version string `json:"version"` +} + +// ChatResponseUrl is task urls of ChatResponse +type ChatResponseUrl struct { + Stream string `json:"stream"` + Get string `json:"get"` + Cancel string `json:"cancel"` +} diff --git a/relay/adaptor/replicate/model_test.go b/relay/adaptor/replicate/model_test.go new file mode 100644 index 00000000..6cde5e94 --- /dev/null +++ b/relay/adaptor/replicate/model_test.go @@ -0,0 +1,106 @@ +package replicate + +import ( + "bytes" + "image" + "image/draw" + "image/png" + "io" + "mime/multipart" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +type nopCloser struct { + io.Reader +} + +func (n nopCloser) Close() error { return nil } + +// Custom FileHeader to override Open method +type customFileHeader struct { + *multipart.FileHeader + openFunc func() (multipart.File, error) +} + +func (c *customFileHeader) Open() (multipart.File, error) { + return c.openFunc() +} + +func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) { + // Create a simple image for testing + img := image.NewRGBA(image.Rect(0, 0, 10, 10)) + draw.Draw(img, img.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src) + var imgBuf bytes.Buffer + err := png.Encode(&imgBuf, img) + require.NoError(t, err) + + // Create a simple mask for testing + mask := image.NewRGBA(image.Rect(0, 0, 10, 10)) + draw.Draw(mask, mask.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src) + var maskBuf bytes.Buffer + err = png.Encode(&maskBuf, mask) + require.NoError(t, err) + + // Create a multipart.FileHeader from the image and mask bytes + imgFileHeader, err := createFileHeader("image", "test.png", imgBuf.Bytes()) + require.NoError(t, err) + maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes()) + require.NoError(t, err) + + req := &OpenaiImageEditRequest{ + Image: imgFileHeader, + Mask: maskFileHeader, + Prompt: "Test prompt", + Model: "test-model", + ResponseFormat: "b64_json", + } + + fluxReq, err := req.toFluxRemixRequest() + require.NoError(t, err) + require.NotNil(t, fluxReq) + require.Equal(t, req.Prompt, fluxReq.Input.Prompt) + require.NotEmpty(t, fluxReq.Input.Image) + require.NotEmpty(t, fluxReq.Input.Mask) +} + +// createFileHeader creates a multipart.FileHeader from file bytes +func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // Create a form file field + part, err := writer.CreateFormFile(fieldname, filename) + if err != nil { + return nil, err + } + + // Write the file bytes to the form file field + _, err = part.Write(fileBytes) + if err != nil { + return nil, err + } + + // Close the writer to finalize the form + err = writer.Close() + if err != nil { + return nil, err + } + + // Parse the multipart form + req := &http.Request{ + Header: http.Header{}, + Body: io.NopCloser(body), + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + err = req.ParseMultipartForm(int64(body.Len())) + if err != nil { + return nil, err + } + + // Retrieve the file header from the parsed form + fileHeader := req.MultipartForm.File[fieldname][0] + return fileHeader, nil +} diff --git a/relay/adaptor/stepfun/constants.go b/relay/adaptor/stepfun/constants.go index 
a82e562b..6a2346ca 100644 --- a/relay/adaptor/stepfun/constants.go +++ b/relay/adaptor/stepfun/constants.go @@ -1,7 +1,13 @@ package stepfun var ModelList = []string{ + "step-1-8k", "step-1-32k", + "step-1-128k", + "step-1-256k", + "step-1-flash", + "step-2-16k", + "step-1v-8k", "step-1v-32k", - "step-1-200k", + "step-1x-medium", } diff --git a/relay/adaptor/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go index 0de92d4a..9d086eed 100644 --- a/relay/adaptor/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -2,16 +2,17 @@ package tencent import ( "errors" + "io" + "net/http" + "strconv" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" ) // https://cloud.tencent.com/document/api/1729/101837 @@ -58,7 +59,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return tencentRequest, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/tencent/constants.go b/relay/adaptor/tencent/constants.go index be415a94..e8631e5f 100644 --- a/relay/adaptor/tencent/constants.go +++ b/relay/adaptor/tencent/constants.go @@ -5,4 +5,5 @@ var ModelList = []string{ "hunyuan-standard", "hunyuan-standard-256K", "hunyuan-pro", + "hunyuan-vision", } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 365e33ae..c402e543 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strconv" @@ -21,6 +20,7 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" @@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { Model: &request.Model, Stream: &request.Stream, Messages: messages, - TopP: &request.TopP, - Temperature: &request.Temperature, + TopP: request.TopP, + Temperature: request.Temperature, } } diff --git a/relay/adaptor/vertexai/adaptor.go b/relay/adaptor/vertexai/adaptor.go index 3fab4a45..131506eb 100644 --- a/relay/adaptor/vertexai/adaptor.go +++ b/relay/adaptor/vertexai/adaptor.go @@ -1,18 +1,20 @@ package vertexai import ( - "errors" "fmt" "io" "net/http" + "slices" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - relaymodel "github.com/songquanpeng/one-api/relay/model" + relayModel "github.com/songquanpeng/one-api/relay/model" ) var _ adaptor.Adaptor = new(Adaptor) @@ -24,14 +26,27 @@ type Adaptor struct{} func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) ConvertRequest(c *gin.Context, 
relayMode int, request *model.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") } - adaptor := GetAdaptor(request.Model) + adaptor := GetAdaptor(meta.ActualModelName) if adaptor == nil { - return nil, errors.New("adaptor not found") + return nil, errors.Errorf("cannot found vertex image adaptor for model %s", meta.ActualModelName) + } + + return adaptor.ConvertImageRequest(c, request) +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + meta := meta.GetByContext(c) + + adaptor := GetAdaptor(meta.ActualModelName) + if adaptor == nil { + return nil, errors.Errorf("cannot found vertex chat adaptor for model %s", meta.ActualModelName) } return adaptor.ConvertRequest(c, relayMode, request) @@ -40,9 +55,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { adaptor := GetAdaptor(meta.ActualModelName) if adaptor == nil { - return nil, &relaymodel.ErrorWithStatusCode{ + return nil, &relayModel.ErrorWithStatusCode{ StatusCode: http.StatusInternalServerError, - Error: relaymodel.Error{ + Error: relayModel.Error{ Message: "adaptor not found", }, } @@ -60,14 +75,19 @@ func (a *Adaptor) GetChannelName() string { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - suffix := "" - if strings.HasPrefix(meta.ActualModelName, "gemini") { + var suffix string + switch { + case strings.HasPrefix(meta.ActualModelName, "gemini"): if meta.IsStream { suffix = "streamGenerateContent?alt=sse" } else { suffix = "generateContent" } - } else { + case slices.Contains(imagen.ModelList, meta.ActualModelName): + return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/imagen-3.0-generate-001:predict", + meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region, + ), nil + default: if meta.IsStream { suffix = "streamRawPredict?alt=sse" } else { @@ -85,6 +105,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { suffix, ), nil } + return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", meta.Config.Region, @@ -105,13 +126,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me return nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") - } - return request, nil -} - func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index b39e2dda..14733924 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -13,7 +13,12 @@ import ( ) var ModelList = []string{ - "claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229", + "claude-3-haiku@20240307", + "claude-3-opus@20240229", + "claude-3-sonnet@20240229", + 
"claude-3-5-sonnet@20240620", + "claude-3-5-sonnet-v2@20241022", + "claude-3-5-haiku@20241022", } const anthropicVersion = "vertex-2023-10-16" @@ -45,6 +50,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return req, nil } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + return nil, errors.New("not support image request") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = anthropic.StreamHandler(c, resp) diff --git a/relay/adaptor/vertexai/claude/model.go b/relay/adaptor/vertexai/claude/model.go index e1bd5dd4..c08ba460 100644 --- a/relay/adaptor/vertexai/claude/model.go +++ b/relay/adaptor/vertexai/claude/model.go @@ -11,8 +11,8 @@ type Request struct { MaxTokens int `json:"max_tokens,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index 43e6cbcd..871a616f 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -15,7 +15,10 @@ import ( ) var ModelList = []string{ - "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", + "gemini-pro", "gemini-pro-vision", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", + "gemini-1.5-pro-002", "gemini-1.5-flash-002", + "gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", } type Adaptor struct { @@ -32,6 +35,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return geminiRequest, nil } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + return nil, errors.New("not support image request") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string diff --git a/relay/adaptor/vertexai/imagen/adaptor.go b/relay/adaptor/vertexai/imagen/adaptor.go new file mode 100644 index 00000000..0dc428ec --- /dev/null +++ b/relay/adaptor/vertexai/imagen/adaptor.go @@ -0,0 +1,107 @@ +package imagen + +import ( + "encoding/json" + "io" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +var ModelList = []string{ + "imagen-3.0-generate-001", +} + +type Adaptor struct { +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") + } + if request.N <= 0 { + return nil, errors.New("n must be greater than 0") + } + + switch meta.Mode { + case relaymode.ImagesGenerations: + return convertImageCreateRequest(request) + case relaymode.ImagesEdits: + return nil, errors.New("not implemented") + default: + return 
nil, errors.New("not implemented") + } +} + +func convertImageCreateRequest(request *model.ImageRequest) (any, error) { + return CreateImageRequest{ + Instances: []createImageInstance{ + { + Prompt: request.Prompt, + }, + }, + Parameters: createImageParameters{ + SampleCount: request.N, + }, + }, nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, openai.ErrorWrapper( + errors.Wrap(err, "failed to read response body"), + "read_response_body", + http.StatusInternalServerError, + ) + } + + if resp.StatusCode != http.StatusOK { + return nil, openai.ErrorWrapper( + errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)), + "upstream_response", + http.StatusInternalServerError, + ) + } + + imagenResp := new(CreateImageResponse) + if err := json.Unmarshal(respBody, imagenResp); err != nil { + return nil, openai.ErrorWrapper( + errors.Wrap(err, "failed to decode response body"), + "unmarshal_upstream_response", + http.StatusInternalServerError, + ) + } + + if len(imagenResp.Predictions) == 0 { + return nil, openai.ErrorWrapper( + errors.New("empty predictions"), + "empty_predictions", + http.StatusInternalServerError, + ) + } + + oaiResp := openai.ImageResponse{ + Created: time.Now().Unix(), + } + for _, prediction := range imagenResp.Predictions { + oaiResp.Data = append(oaiResp.Data, openai.ImageData{ + B64Json: prediction.BytesBase64Encoded, + }) + } + + c.JSON(http.StatusOK, oaiResp) + return nil, nil +} diff --git a/relay/adaptor/vertexai/imagen/model.go b/relay/adaptor/vertexai/imagen/model.go new file mode 100644 index 00000000..b890d30d --- /dev/null +++ b/relay/adaptor/vertexai/imagen/model.go @@ -0,0 +1,23 @@ +package imagen + +type CreateImageRequest struct { + Instances []createImageInstance `json:"instances" binding:"required,min=1"` + Parameters createImageParameters `json:"parameters" binding:"required"` +} + +type createImageInstance struct { + Prompt string `json:"prompt"` +} + +type createImageParameters struct { + SampleCount int `json:"sample_count" binding:"required,min=1"` +} + +type CreateImageResponse struct { + Predictions []createImageResponsePrediction `json:"predictions"` +} + +type createImageResponsePrediction struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` +} diff --git a/relay/adaptor/vertexai/registry.go b/relay/adaptor/vertexai/registry.go index 41099f02..37dd06ef 100644 --- a/relay/adaptor/vertexai/registry.go +++ b/relay/adaptor/vertexai/registry.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude" gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" ) @@ -13,8 +14,9 @@ import ( type VertexAIModelType int const ( - VerterAIClaude VertexAIModelType = iota + 1 - VerterAIGemini + VertexAIClaude VertexAIModelType = iota + 1 + VertexAIGemini + VertexAIImagen ) var modelMapping = map[string]VertexAIModelType{} @@ -23,27 +25,35 @@ var modelList = []string{} func init() { modelList = append(modelList, 
claude.ModelList...) for _, model := range claude.ModelList { - modelMapping[model] = VerterAIClaude + modelMapping[model] = VertexAIClaude } modelList = append(modelList, gemini.ModelList...) for _, model := range gemini.ModelList { - modelMapping[model] = VerterAIGemini + modelMapping[model] = VertexAIGemini + } + + modelList = append(modelList, imagen.ModelList...) + for _, model := range imagen.ModelList { + modelMapping[model] = VertexAIImagen } } type innerAIAdapter interface { ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) } func GetAdaptor(model string) innerAIAdapter { adaptorType := modelMapping[model] switch adaptorType { - case VerterAIClaude: + case VertexAIClaude: return &claude.Adaptor{} - case VerterAIGemini: + case VertexAIGemini: return &gemini.Adaptor{} + case VertexAIImagen: + return &imagen.Adaptor{} default: return nil } diff --git a/relay/adaptor/xai/constants.go b/relay/adaptor/xai/constants.go new file mode 100644 index 00000000..9082b999 --- /dev/null +++ b/relay/adaptor/xai/constants.go @@ -0,0 +1,5 @@ +package xai + +var ModelList = []string{ + "grok-beta", +} diff --git a/relay/adaptor/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go index b5967f26..62aa7240 100644 --- a/relay/adaptor/xunfei/adaptor.go +++ b/relay/adaptor/xunfei/adaptor.go @@ -2,14 +2,15 @@ package xunfei import ( "errors" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" ) type Adaptor struct { @@ -39,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, nil } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/xunfei/constants.go b/relay/adaptor/xunfei/constants.go index 12a56210..5b82ac29 100644 --- a/relay/adaptor/xunfei/constants.go +++ b/relay/adaptor/xunfei/constants.go @@ -5,6 +5,8 @@ var ModelList = []string{ "SparkDesk-v1.1", "SparkDesk-v2.1", "SparkDesk-v3.1", + "SparkDesk-v3.1-128K", "SparkDesk-v3.5", + "SparkDesk-v3.5-32K", "SparkDesk-v4.0", } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index ef6120e5..3984ba5a 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, } func parseAPIVersionByModelName(modelName string) string { - parts := strings.Split(modelName, "-") - if len(parts) == 2 { - return parts[1] + index := strings.IndexAny(modelName, "-") + if index != -1 { + return modelName[index+1:] } return "" } @@ -283,13 +283,17 @@ func parseAPIVersionByModelName(modelName string) string { func apiVersion2domain(apiVersion string) string { switch apiVersion { case "v1.1": - return "general" + return "lite" case "v2.1": return "generalv2" case "v3.1": return "generalv3" + case "v3.1-128K": + return "pro-128k" case "v3.5": return "generalv3.5" + case "v3.5-32K": + return "max-32k" case "v4.0": 
return "4.0Ultra" } @@ -297,7 +301,17 @@ func apiVersion2domain(apiVersion string) string { } func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { + var authUrl string domain := apiVersion2domain(apiVersion) - authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + switch apiVersion { + case "v3.1-128K": + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret) + break + case "v3.5-32K": + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret) + break + default: + authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + } return domain, authUrl } diff --git a/relay/adaptor/xunfei/model.go b/relay/adaptor/xunfei/model.go index 1f37c046..c9fb1bb8 100644 --- a/relay/adaptor/xunfei/model.go +++ b/relay/adaptor/xunfei/model.go @@ -19,11 +19,11 @@ type ChatRequest struct { } `json:"header"` Parameter struct { Chat struct { - Domain string `json:"domain,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Auditing bool `json:"auditing,omitempty"` + Domain string `json:"domain,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` } `json:"chat"` } `json:"parameter"` Payload struct { diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 78b01fb3..9a535e6b 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -3,16 +3,17 @@ package zhipu import ( "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "math" - "net/http" - "strings" ) type Adaptor struct { @@ -65,13 +66,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) return baiduEmbeddingRequest, err default: - // TopP (0.0, 1.0) - request.TopP = math.Min(0.99, request.TopP) - request.TopP = math.Max(0.01, request.TopP) + // TopP [0.0, 1.0] + request.TopP = helper.Float64PtrMax(request.TopP, 1) + request.TopP = helper.Float64PtrMin(request.TopP, 0) - // Temperature (0.0, 1.0) - request.Temperature = math.Min(0.99, request.Temperature) - request.Temperature = math.Max(0.01, request.Temperature) + // Temperature [0.0, 1.0] + request.Temperature = helper.Float64PtrMax(request.Temperature, 1) + request.Temperature = helper.Float64PtrMin(request.Temperature, 0) a.SetVersionByModeName(request.Model) if a.APIVersion == "v4" { return request, nil @@ -80,7 +81,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/adaptor/zhipu/main.go b/relay/adaptor/zhipu/main.go index ab3a5678..0489136e 100644 --- 
a/relay/adaptor/zhipu/main.go +++ b/relay/adaptor/zhipu/main.go @@ -3,7 +3,6 @@ package zhipu import ( "bufio" "encoding/json" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -15,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" diff --git a/relay/adaptor/zhipu/model.go b/relay/adaptor/zhipu/model.go index f91de1dc..9a692c01 100644 --- a/relay/adaptor/zhipu/model.go +++ b/relay/adaptor/zhipu/model.go @@ -1,8 +1,9 @@ package zhipu import ( - "github.com/songquanpeng/one-api/relay/model" "time" + + "github.com/songquanpeng/one-api/relay/model" ) type Message struct { @@ -12,8 +13,8 @@ type Message struct { type Request struct { Prompt []Message `json:"prompt"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` RequestId string `json:"request_id,omitempty"` Incremental bool `json:"incremental,omitempty"` } diff --git a/relay/apitype/define.go b/relay/apitype/define.go index cf7b6a0d..0c6a5ff1 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -19,6 +19,7 @@ const ( DeepL VertexAI Proxy + Replicate Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go index ced0c667..c8c42a15 100644 --- a/relay/billing/ratio/image.go +++ b/relay/billing/ratio/image.go @@ -30,6 +30,14 @@ var ImageSizeRatios = map[string]map[string]float64{ "720x1280": 1, "1280x720": 1, }, + "step-1x-medium": { + "256x256": 1, + "512x512": 1, + "768x768": 1, + "1024x1024": 1, + "1280x800": 1, + "800x1280": 1, + }, } var ImageGenerationAmounts = map[string][2]int{ @@ -39,6 +47,7 @@ var ImageGenerationAmounts = map[string][2]int{ "ali-stable-diffusion-v1.5": {1, 4}, // Ali "wanx-v1": {1, 4}, // Ali "cogview-3": {1, 1}, + "step-1x-medium": {1, 1}, } var ImagePromptLengthLimitations = map[string]int{ @@ -48,6 +57,7 @@ var ImagePromptLengthLimitations = map[string]int{ "ali-stable-diffusion-v1.5": 4000, "wanx-v1": 4000, "cogview-3": 833, + "step-1x-medium": 4000, } var ImageOriginModelName = map[string]string{ diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 7bc6cd54..cdcbc9d2 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -3,6 +3,7 @@ package ratio import ( "encoding/json" "fmt" + "math" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -22,63 +23,81 @@ const ( // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ // https://openai.com/pricing - "gpt-4": 15, - "gpt-4-0314": 15, - "gpt-4-0613": 15, - "gpt-4-32k": 30, - "gpt-4-32k-0314": 30, - "gpt-4-32k-0613": 30, - "gpt-4-1106-preview": 5, // $0.01 / 1K tokens - "gpt-4-0125-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo": 5, // $0.01 / 1K tokens - "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens - "gpt-4o": 2.5, // $0.005 / 1K tokens - "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens - "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens - "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens - "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-3.5-turbo": 0.25, // $0.0005 / 
1K tokens - "gpt-3.5-turbo-0301": 0.75, - "gpt-3.5-turbo-0613": 0.75, - "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens - "gpt-3.5-turbo-16k-0613": 1.5, - "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens - "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens - "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens - "davinci-002": 1, // $0.002 / 1K tokens - "babbage-002": 0.2, // $0.0004 / 1K tokens - "text-ada-001": 0.2, - "text-babbage-001": 0.25, - "text-curie-001": 1, - "text-davinci-002": 10, - "text-davinci-003": 10, - "text-davinci-edit-001": 10, - "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens - "tts-1": 7.5, // $0.015 / 1K characters - "tts-1-1106": 7.5, - "tts-1-hd": 15, // $0.030 / 1K characters - "tts-1-hd-1106": 15, - "davinci": 10, - "curie": 10, - "babbage": 10, - "ada": 10, - "text-embedding-ada-002": 0.05, - "text-embedding-3-small": 0.01, - "text-embedding-3-large": 0.065, - "text-search-ada-doc-001": 10, - "text-moderation-stable": 0.1, - "text-moderation-latest": 0.1, - "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image - "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image + "gpt-4": 15, + "gpt-4-0314": 15, + "gpt-4-0613": 15, + "gpt-4-32k": 30, + "gpt-4-32k-0314": 30, + "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo": 5, // $0.01 / 1K tokens + "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens + "gpt-4o": 2.5, // $0.005 / 1K tokens + "chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens + "gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens + "gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens + "gpt-4o-mini": 0.075, // $0.00015 / 1K tokens + "gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens + // Audio billing will mix text and audio tokens, the unit price is different. 
+ // Here records the cost of text, the cost multiplier of audio + // relative to text is in AudioRatio + "gpt-4o-audio-preview": 1.25, // $0.0025 / 1K tokens + "gpt-4o-audio-preview-2024-12-17": 1.25, // $0.0025 / 1K tokens + "gpt-4o-audio-preview-2024-10-01": 1.25, // $0.0025 / 1K tokens + "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens + "gpt-3.5-turbo-0301": 0.75, + "gpt-3.5-turbo-0613": 0.75, + "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": 1.5, + "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens + "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens + "o1": 7.5, // $15.00 / 1M input tokens + "o1-2024-12-17": 7.5, + "o1-preview": 7.5, // $15.00 / 1M input tokens + "o1-preview-2024-09-12": 7.5, + "o1-mini": 1.5, // $3.00 / 1M input tokens + "o1-mini-2024-09-12": 1.5, + "davinci-002": 1, // $0.002 / 1K tokens + "babbage-002": 0.2, // $0.0004 / 1K tokens + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.05, + "text-embedding-3-small": 0.01, + "text-embedding-3-large": 0.065, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, "claude-2.1": 8.0 / 1000 * USD, "claude-3-haiku-20240307": 0.25 / 1000 * USD, + "claude-3-5-haiku-20241022": 1.0 / 1000 * USD, "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, + "claude-3-5-sonnet-20241022": 3.0 / 1000 * USD, + "claude-3-5-sonnet-latest": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 "ERNIE-4.0-8K": 0.120 * RMB, @@ -98,11 +117,15 @@ var ModelRatio = map[string]float64{ "bge-large-en": 0.002 * RMB, "tao-8k": 0.002 * RMB, // https://ai.google.dev/pricing - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-1.0-pro": 1, - "gemini-1.5-flash": 1, - "gemini-1.5-pro": 1, - "aqa": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro": 1, + "gemini-1.5-pro": 1, + "gemini-1.5-pro-001": 1, + "gemini-1.5-flash": 1, + "gemini-1.5-flash-001": 1, + "gemini-2.0-flash-exp": 1, + "gemini-2.0-flash-thinking-exp": 1, + "aqa": 1, // https://open.bigmodel.cn/pricing "glm-4": 0.1 * RMB, "glm-4v": 0.1 * RMB, @@ -114,27 +137,94 @@ var ModelRatio = map[string]float64{ "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "cogview-3": 0.25 * RMB, // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing - "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens - "qwen-plus": 1.4286, // ¥0.02 / 1k tokens - "qwen-max": 1.4286, // ¥0.02 / 1k tokens - "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens - "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens - "ali-stable-diffusion-xl": 8, - "ali-stable-diffusion-v1.5": 8, - "wanx-v1": 8, - "SparkDesk": 1.2858, // ¥0.018 / 1k 
tokens - "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens - "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens - "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens - "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens - "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens - "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens - "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 - "ChatStd": 0.01 * RMB, - "ChatPro": 0.1 * RMB, + "qwen-turbo": 1.4286, // ¥0.02 / 1k tokens + "qwen-turbo-latest": 1.4286, + "qwen-plus": 1.4286, + "qwen-plus-latest": 1.4286, + "qwen-max": 1.4286, + "qwen-max-latest": 1.4286, + "qwen-max-longcontext": 1.4286, + "qwen-vl-max": 1.4286, + "qwen-vl-max-latest": 1.4286, + "qwen-vl-plus": 1.4286, + "qwen-vl-plus-latest": 1.4286, + "qwen-vl-ocr": 1.4286, + "qwen-vl-ocr-latest": 1.4286, + "qwen-audio-turbo": 1.4286, + "qwen-math-plus": 1.4286, + "qwen-math-plus-latest": 1.4286, + "qwen-math-turbo": 1.4286, + "qwen-math-turbo-latest": 1.4286, + "qwen-coder-plus": 1.4286, + "qwen-coder-plus-latest": 1.4286, + "qwen-coder-turbo": 1.4286, + "qwen-coder-turbo-latest": 1.4286, + "qwq-32b-preview": 1.4286, + "qwen2.5-72b-instruct": 1.4286, + "qwen2.5-32b-instruct": 1.4286, + "qwen2.5-14b-instruct": 1.4286, + "qwen2.5-7b-instruct": 1.4286, + "qwen2.5-3b-instruct": 1.4286, + "qwen2.5-1.5b-instruct": 1.4286, + "qwen2.5-0.5b-instruct": 1.4286, + "qwen2-72b-instruct": 1.4286, + "qwen2-57b-a14b-instruct": 1.4286, + "qwen2-7b-instruct": 1.4286, + "qwen2-1.5b-instruct": 1.4286, + "qwen2-0.5b-instruct": 1.4286, + "qwen1.5-110b-chat": 1.4286, + "qwen1.5-72b-chat": 1.4286, + "qwen1.5-32b-chat": 1.4286, + "qwen1.5-14b-chat": 1.4286, + "qwen1.5-7b-chat": 1.4286, + "qwen1.5-1.8b-chat": 1.4286, + "qwen1.5-0.5b-chat": 1.4286, + "qwen-72b-chat": 1.4286, + "qwen-14b-chat": 1.4286, + "qwen-7b-chat": 1.4286, + "qwen-1.8b-chat": 1.4286, + "qwen-1.8b-longcontext-chat": 1.4286, + "qwen2-vl-7b-instruct": 1.4286, + "qwen2-vl-2b-instruct": 1.4286, + "qwen-vl-v1": 1.4286, + "qwen-vl-chat-v1": 1.4286, + "qwen2-audio-instruct": 1.4286, + "qwen-audio-chat": 1.4286, + "qwen2.5-math-72b-instruct": 1.4286, + "qwen2.5-math-7b-instruct": 1.4286, + "qwen2.5-math-1.5b-instruct": 1.4286, + "qwen2-math-72b-instruct": 1.4286, + "qwen2-math-7b-instruct": 1.4286, + "qwen2-math-1.5b-instruct": 1.4286, + "qwen2.5-coder-32b-instruct": 1.4286, + "qwen2.5-coder-14b-instruct": 1.4286, + "qwen2.5-coder-7b-instruct": 1.4286, + "qwen2.5-coder-3b-instruct": 1.4286, + "qwen2.5-coder-1.5b-instruct": 1.4286, + "qwen2.5-coder-0.5b-instruct": 1.4286, + "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "text-embedding-v3": 0.05, + "text-embedding-v2": 0.05, + "text-embedding-async-v2": 0.05, + "text-embedding-async-v1": 0.05, + "ali-stable-diffusion-xl": 8.00, + "ali-stable-diffusion-v1.5": 8.00, + "wanx-v1": 8.00, + "SparkDesk": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens + "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens + "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k 
tokens + "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 + "ChatStd": 0.01 * RMB, + "ChatPro": 0.1 * RMB, // https://platform.moonshot.cn/pricing "moonshot-v1-8k": 0.012 * RMB, "moonshot-v1-32k": 0.024 * RMB, @@ -158,23 +248,33 @@ var ModelRatio = map[string]float64{ "mistral-embed": 0.1 / 1000 * USD, // https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed "gemma-7b-it": 0.07 / 1000000 * USD, - "mixtral-8x7b-32768": 0.24 / 1000000 * USD, - "llama3-8b-8192": 0.05 / 1000000 * USD, - "llama3-70b-8192": 0.59 / 1000000 * USD, "gemma2-9b-it": 0.20 / 1000000 * USD, - "llama-3.1-405b-reasoning": 0.89 / 1000000 * USD, "llama-3.1-70b-versatile": 0.59 / 1000000 * USD, "llama-3.1-8b-instant": 0.05 / 1000000 * USD, + "llama-3.2-11b-text-preview": 0.05 / 1000000 * USD, + "llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD, + "llama-3.2-1b-preview": 0.05 / 1000000 * USD, + "llama-3.2-3b-preview": 0.05 / 1000000 * USD, + "llama-3.2-90b-text-preview": 0.59 / 1000000 * USD, + "llama-guard-3-8b": 0.05 / 1000000 * USD, + "llama3-70b-8192": 0.59 / 1000000 * USD, + "llama3-8b-8192": 0.05 / 1000000 * USD, "llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD, "llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD, + "mixtral-8x7b-32768": 0.24 / 1000000 * USD, // https://platform.lingyiwanwu.com/docs#-计费单元 "yi-34b-chat-0205": 2.5 / 1000 * RMB, "yi-34b-chat-200k": 12.0 / 1000 * RMB, "yi-vl-plus": 6.0 / 1000 * RMB, - // stepfun todo - "step-1v-32k": 0.024 * RMB, - "step-1-32k": 0.024 * RMB, - "step-1-200k": 0.15 * RMB, + // https://platform.stepfun.com/docs/pricing/details + "step-1-8k": 0.005 / 1000 * RMB, + "step-1-32k": 0.015 / 1000 * RMB, + "step-1-128k": 0.040 / 1000 * RMB, + "step-1-256k": 0.095 / 1000 * RMB, + "step-1-flash": 0.001 / 1000 * RMB, + "step-2-16k": 0.038 / 1000 * RMB, + "step-1v-8k": 0.005 / 1000 * RMB, + "step-1v-32k": 0.015 / 1000 * RMB, // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens @@ -192,16 +292,131 @@ var ModelRatio = map[string]float64{ "deepl-zh": 25.0 / 1000 * USD, "deepl-en": 25.0 / 1000 * USD, "deepl-ja": 25.0 / 1000 * USD, + // https://console.x.ai/ + "grok-beta": 5.0 / 1000 * USD, + // vertex imagen3 + // https://cloud.google.com/vertex-ai/generative-ai/pricing#imagen-models + "imagen-3.0-generate-001": 0.02 * USD, + // replicate charges based on the number of generated images + // https://replicate.com/pricing + "black-forest-labs/flux-1.1-pro": 0.04 * USD, + "black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD, + "black-forest-labs/flux-canny-dev": 0.025 * USD, + "black-forest-labs/flux-canny-pro": 0.05 * USD, + "black-forest-labs/flux-depth-dev": 0.025 * USD, + "black-forest-labs/flux-depth-pro": 0.05 * USD, + "black-forest-labs/flux-dev": 0.025 * USD, + "black-forest-labs/flux-dev-lora": 0.032 * USD, + "black-forest-labs/flux-fill-dev": 0.04 * USD, + "black-forest-labs/flux-fill-pro": 0.05 * USD, + "black-forest-labs/flux-pro": 0.055 * USD, + "black-forest-labs/flux-redux-dev": 0.025 * USD, + "black-forest-labs/flux-redux-schnell": 0.003 * USD, + "black-forest-labs/flux-schnell": 0.003 * USD, + "black-forest-labs/flux-schnell-lora": 0.02 * USD, + "ideogram-ai/ideogram-v2": 0.08 * USD, + "ideogram-ai/ideogram-v2-turbo": 
0.05 * USD, + "recraft-ai/recraft-v3": 0.04 * USD, + "recraft-ai/recraft-v3-svg": 0.08 * USD, + "stability-ai/stable-diffusion-3": 0.035 * USD, + "stability-ai/stable-diffusion-3.5-large": 0.065 * USD, + "stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, + "stability-ai/stable-diffusion-3.5-medium": 0.035 * USD, + // replicate chat models + "ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD, + "ibm-granite/granite-3.0-2b-instruct": 0.030 * USD, + "ibm-granite/granite-3.0-8b-instruct": 0.050 * USD, + "ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD, + "meta/llama-2-13b": 0.100 * USD, + "meta/llama-2-13b-chat": 0.100 * USD, + "meta/llama-2-70b": 0.650 * USD, + "meta/llama-2-70b-chat": 0.650 * USD, + "meta/llama-2-7b": 0.050 * USD, + "meta/llama-2-7b-chat": 0.050 * USD, + "meta/meta-llama-3.1-405b-instruct": 9.500 * USD, + "meta/meta-llama-3-70b": 0.650 * USD, + "meta/meta-llama-3-70b-instruct": 0.650 * USD, + "meta/meta-llama-3-8b": 0.050 * USD, + "meta/meta-llama-3-8b-instruct": 0.050 * USD, + "mistralai/mistral-7b-instruct-v0.2": 0.050 * USD, + "mistralai/mistral-7b-v0.1": 0.050 * USD, + "mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD, +} + +// AudioRatio represents the price ratio between audio tokens and text tokens +var AudioRatio = map[string]float64{ + "gpt-4o-audio-preview": 16, + "gpt-4o-audio-preview-2024-12-17": 16, + "gpt-4o-audio-preview-2024-10-01": 40, +} + +// GetAudioPromptRatio returns the audio prompt ratio for the given model. +func GetAudioPromptRatio(actualModelName string) float64 { + var v float64 + if ratio, ok := AudioRatio[actualModelName]; ok { + v = ratio + } else { + v = 16 + } + + return v +} + +// AudioCompletionRatio is the completion ratio for audio models. +var AudioCompletionRatio = map[string]float64{ + "whisper-1": 0, + "gpt-4o-audio-preview": 2, + "gpt-4o-audio-preview-2024-12-17": 2, + "gpt-4o-audio-preview-2024-10-01": 2, +} + +// GetAudioCompletionRatio returns the completion ratio for audio models. +func GetAudioCompletionRatio(actualModelName string) float64 { + var v float64 + if ratio, ok := AudioCompletionRatio[actualModelName]; ok { + v = ratio + } else { + v = 2 + } + + return v +} + +// AudioTokensPerSecond is the number of audio tokens per second for each model. +var AudioPromptTokensPerSecond = map[string]float64{ + // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens + "whisper-1": 1000 / 20, + // gpt-4o-audio series processes 10 tokens per second + "gpt-4o-audio-preview": 10, + "gpt-4o-audio-preview-2024-12-17": 10, + "gpt-4o-audio-preview-2024-10-01": 10, +} + +// GetAudioPromptTokensPerSecond returns the number of audio tokens per second +// for the given model. 
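+//
+// As a rough worked example based on the table above (an illustrative
+// calculation, not an upstream-documented figure): whisper-1 maps to
+// 1000/20 = 50 audio tokens per second, so a 60-second clip is counted as
+// about 60 * 50 = 3000 prompt tokens, e.g.
+//
+//	promptTokens := 60 * GetAudioPromptTokensPerSecond("whisper-1") // 3000
+//
+// Models missing from the table fall back to 10 tokens per second, the same
+// rate as the gpt-4o-audio entries.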
+func GetAudioPromptTokensPerSecond(actualModelName string) int { + var v float64 + if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok { + v = tokensPerSecond + } else { + v = 10 + } + + return int(math.Ceil(v)) } var CompletionRatio = map[string]float64{ // aws llama3 "llama3-8b-8192(33)": 0.0006 / 0.0003, "llama3-70b-8192(33)": 0.0035 / 0.00265, + // whisper + "whisper-1": 0, // only count input tokens } -var DefaultModelRatio map[string]float64 -var DefaultCompletionRatio map[string]float64 +var ( + DefaultModelRatio map[string]float64 + DefaultCompletionRatio map[string]float64 +) func init() { DefaultModelRatio = make(map[string]float64) @@ -254,19 +469,21 @@ func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } + model := fmt.Sprintf("%s(%d)", name, channelType) - if ratio, ok := ModelRatio[model]; ok { - return ratio - } - if ratio, ok := DefaultModelRatio[model]; ok { - return ratio - } - if ratio, ok := ModelRatio[name]; ok { - return ratio - } - if ratio, ok := DefaultModelRatio[name]; ok { - return ratio + + for _, targetName := range []string{model, name} { + for _, ratioMap := range []map[string]float64{ + ModelRatio, + DefaultModelRatio, + AudioRatio, + } { + if ratio, ok := ratioMap[targetName]; ok { + return ratio + } + } } + logger.SysError("model ratio not found: " + name) return 30 } @@ -289,18 +506,19 @@ func GetCompletionRatio(name string, channelType int) float64 { name = strings.TrimSuffix(name, "-internet") } model := fmt.Sprintf("%s(%d)", name, channelType) - if ratio, ok := CompletionRatio[model]; ok { - return ratio - } - if ratio, ok := DefaultCompletionRatio[model]; ok { - return ratio - } - if ratio, ok := CompletionRatio[name]; ok { - return ratio - } - if ratio, ok := DefaultCompletionRatio[name]; ok { - return ratio + + for _, targetName := range []string{model, name} { + for _, ratioMap := range []map[string]float64{ + CompletionRatio, + DefaultCompletionRatio, + AudioCompletionRatio, + } { + if ratio, ok := ratioMap[targetName]; ok { + return ratio + } + } } + if strings.HasPrefix(name, "gpt-3.5") { if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates @@ -313,16 +531,25 @@ func GetCompletionRatio(name string, channelType int) float64 { return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") { - if strings.HasPrefix(name, "gpt-4o-mini") { + if strings.HasPrefix(name, "gpt-4o") { + if name == "gpt-4o-2024-05-13" { + return 3 + } return 4 } if strings.HasPrefix(name, "gpt-4-turbo") || - strings.HasPrefix(name, "gpt-4o") || strings.HasSuffix(name, "preview") { return 3 } return 2 } + // including o1/o1-preview/o1-mini + if strings.HasPrefix(name, "o1") { + return 4 + } + if name == "chatgpt-4o-latest" { + return 3 + } if strings.HasPrefix(name, "claude-3") { return 5 } @@ -338,6 +565,7 @@ func GetCompletionRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "deepseek-") { return 2 } + switch name { case "llama2-70b-4096": return 0.8 / 0.64 @@ -351,6 +579,37 @@ func GetCompletionRatio(name string, channelType int) float64 { return 3 case "command-r-plus": return 5 + case "grok-beta": + return 3 + // Replicate Models + // https://replicate.com/pricing + case "ibm-granite/granite-20b-code-instruct-8k": + return 5 + case "ibm-granite/granite-3.0-2b-instruct": + return 8.333333333333334 + case 
"ibm-granite/granite-3.0-8b-instruct", + "ibm-granite/granite-8b-code-instruct-128k": + return 5 + case "meta/llama-2-13b", + "meta/llama-2-13b-chat", + "meta/llama-2-7b", + "meta/llama-2-7b-chat", + "meta/meta-llama-3-8b", + "meta/meta-llama-3-8b-instruct": + return 5 + case "meta/llama-2-70b", + "meta/llama-2-70b-chat", + "meta/meta-llama-3-70b", + "meta/meta-llama-3-70b-instruct": + return 2.750 / 0.650 // ≈4.230769 + case "meta/meta-llama-3.1-405b-instruct": + return 1 + case "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mistral-7b-v0.1": + return 5 + case "mistralai/mixtral-8x7b-instruct-v0.1": + return 1.000 / 0.300 // ≈3.333333 } + return 1 } diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go new file mode 100644 index 00000000..435ab366 --- /dev/null +++ b/relay/channel/tencent/main.go @@ -0,0 +1,238 @@ +package tencent + +// import ( +// "bufio" +// "crypto/hmac" +// "crypto/sha1" +// "encoding/base64" +// "encoding/json" +// "github.com/pkg/errors" +// "fmt" +// "github.com/gin-gonic/gin" +// "github.com/songquanpeng/one-api/common" +// "github.com/songquanpeng/one-api/common/helper" +// "github.com/songquanpeng/one-api/common/logger" +// "github.com/songquanpeng/one-api/relay/channel/openai" +// "github.com/songquanpeng/one-api/relay/constant" +// "github.com/songquanpeng/one-api/relay/model" +// "io" +// "net/http" +// "sort" +// "strconv" +// "strings" +// ) + +// // https://cloud.tencent.com/document/product/1729/97732 + +// func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { +// messages := make([]Message, 0, len(request.Messages)) +// for i := 0; i < len(request.Messages); i++ { +// message := request.Messages[i] +// if message.Role == "system" { +// messages = append(messages, Message{ +// Role: "user", +// Content: message.StringContent(), +// }) +// messages = append(messages, Message{ +// Role: "assistant", +// Content: "Okay", +// }) +// continue +// } +// messages = append(messages, Message{ +// Content: message.StringContent(), +// Role: message.Role, +// }) +// } +// stream := 0 +// if request.Stream { +// stream = 1 +// } +// return &ChatRequest{ +// Timestamp: helper.GetTimestamp(), +// Expired: helper.GetTimestamp() + 24*60*60, +// QueryID: helper.GetUUID(), +// Temperature: request.Temperature, +// TopP: request.TopP, +// Stream: stream, +// Messages: messages, +// } +// } + +// func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { +// fullTextResponse := openai.TextResponse{ +// Object: "chat.completion", +// Created: helper.GetTimestamp(), +// Usage: response.Usage, +// } +// if len(response.Choices) > 0 { +// choice := openai.TextResponseChoice{ +// Index: 0, +// Message: model.Message{ +// Role: "assistant", +// Content: response.Choices[0].Messages.Content, +// }, +// FinishReason: response.Choices[0].FinishReason, +// } +// fullTextResponse.Choices = append(fullTextResponse.Choices, choice) +// } +// return &fullTextResponse +// } + +// func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { +// response := openai.ChatCompletionsStreamResponse{ +// Object: "chat.completion.chunk", +// Created: helper.GetTimestamp(), +// Model: "tencent-hunyuan", +// } +// if len(TencentResponse.Choices) > 0 { +// var choice openai.ChatCompletionsStreamResponseChoice +// choice.Delta.Content = TencentResponse.Choices[0].Delta.Content +// if TencentResponse.Choices[0].FinishReason == "stop" { +// choice.FinishReason = &constant.StopFinishReason +// } +// 
response.Choices = append(response.Choices, choice) +// } +// return &response +// } + +// func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { +// var responseText string +// scanner := bufio.NewScanner(resp.Body) +// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { +// if atEOF && len(data) == 0 { +// return 0, nil, nil +// } +// if i := strings.Index(string(data), "\n"); i >= 0 { +// return i + 1, data[0:i], nil +// } +// if atEOF { +// return len(data), data, nil +// } +// return 0, nil, nil +// }) +// dataChan := make(chan string) +// stopChan := make(chan bool) +// go func() { +// for scanner.Scan() { +// data := scanner.Text() +// if len(data) < 5 { // ignore blank line or wrong format +// continue +// } +// if data[:5] != "data:" { +// continue +// } +// data = data[5:] +// dataChan <- data +// } +// stopChan <- true +// }() +// common.SetEventStreamHeaders(c) +// c.Stream(func(w io.Writer) bool { +// select { +// case data := <-dataChan: +// var TencentResponse ChatResponse +// err := json.Unmarshal([]byte(data), &TencentResponse) +// if err != nil { +// logger.SysError("error unmarshalling stream response: " + err.Error()) +// return true +// } +// response := streamResponseTencent2OpenAI(&TencentResponse) +// if len(response.Choices) != 0 { +// responseText += response.Choices[0].Delta.Content +// } +// jsonResponse, err := json.Marshal(response) +// if err != nil { +// logger.SysError("error marshalling stream response: " + err.Error()) +// return true +// } +// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) +// return true +// case <-stopChan: +// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) +// return false +// } +// }) +// err := resp.Body.Close() +// if err != nil { +// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" +// } +// return nil, responseText +// } + +// func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +// var TencentResponse ChatResponse +// responseBody, err := io.ReadAll(resp.Body) +// if err != nil { +// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil +// } +// err = resp.Body.Close() +// if err != nil { +// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil +// } +// err = json.Unmarshal(responseBody, &TencentResponse) +// if err != nil { +// return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil +// } +// if TencentResponse.Error.Code != 0 { +// return &model.ErrorWithStatusCode{ +// Error: model.Error{ +// Message: TencentResponse.Error.Message, +// Code: TencentResponse.Error.Code, +// }, +// StatusCode: resp.StatusCode, +// }, nil +// } +// fullTextResponse := responseTencent2OpenAI(&TencentResponse) +// fullTextResponse.Model = "hunyuan" +// jsonResponse, err := json.Marshal(fullTextResponse) +// if err != nil { +// return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil +// } +// c.Writer.Header().Set("Content-Type", "application/json") +// c.Writer.WriteHeader(resp.StatusCode) +// _, err = c.Writer.Write(jsonResponse) +// if err != nil { +// return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil +// } +// return nil, &fullTextResponse.Usage +// } + +// func ParseConfig(config string) (appId int64, secretId string, 
secretKey string, err error) { +// parts := strings.Split(config, "|") +// if len(parts) != 3 { +// err = errors.New("invalid tencent config") +// return +// } +// appId, err = strconv.ParseInt(parts[0], 10, 64) +// secretId = parts[1] +// secretKey = parts[2] +// return +// } + +// func GetSign(req ChatRequest, secretKey string) string { +// params := make([]string, 0) +// params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) +// params = append(params, "secret_id="+req.SecretId) +// params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) +// params = append(params, "query_id="+req.QueryID) +// params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) +// params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) +// params = append(params, "stream="+strconv.Itoa(req.Stream)) +// params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) + +// var messageStr string +// for _, msg := range req.Messages { +// messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) +// } +// messageStr = strings.TrimSuffix(messageStr, ",") +// params = append(params, "messages=["+messageStr+"]") + +// sort.Strings(params) +// url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") +// mac := hmac.New(sha1.New, []byte(secretKey)) +// signURL := url +// mac.Write([]byte(signURL)) +// sign := mac.Sum([]byte(nil)) +// return base64.StdEncoding.EncodeToString(sign) +// } diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index a261cff8..f54d0e30 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -46,5 +46,7 @@ const ( VertextAI Proxy SiliconFlow + XAI + Replicate Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index fae3357f..8839b30a 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -37,6 +37,8 @@ func ToAPIType(channelType int) int { apiType = apitype.DeepL case VertextAI: apiType = apitype.VertexAI + case Replicate: + apiType = apitype.Replicate case Proxy: apiType = apitype.Proxy } diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index 8727faea..8e271f4e 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -45,7 +45,9 @@ var ChannelBaseURLs = []string{ "https://api.novita.ai/v3/openai", // 41 "", // 42 "", // 43 - "https://api.siliconflow.cn", // 44 + "https://api.siliconflow.cn", // 44 + "https://api.x.ai", // 45 + "https://api.replicate.com/v1/models/", // 46 } func init() { diff --git a/relay/client/init.go b/relay/client/init.go new file mode 100644 index 00000000..73108700 --- /dev/null +++ b/relay/client/init.go @@ -0,0 +1,36 @@ +package client + +import ( + "net/http" + "os" + "time" + + gutils "github.com/Laisky/go-utils/v4" + "github.com/songquanpeng/one-api/common/config" +) + +var HTTPClient *http.Client +var ImpatientHTTPClient *http.Client + +func init() { + var opts []gutils.HTTPClientOptFunc + + timeout := time.Duration(max(config.IdleTimeout, 30)) * time.Second + opts = append(opts, gutils.WithHTTPClientTimeout(timeout)) + if os.Getenv("RELAY_PROXY") != "" { + opts = append(opts, gutils.WithHTTPClientProxy(os.Getenv("RELAY_PROXY"))) + } + + var err error + HTTPClient, err = gutils.NewHTTPClient(opts...) 
+ if err != nil { + panic(err) + } + + ImpatientHTTPClient, err = gutils.NewHTTPClient( + gutils.WithHTTPClientTimeout(5 * time.Second), + ) + if err != nil { + panic(err) + } +} diff --git a/relay/constant/role/define.go b/relay/constant/role/define.go index 972488c5..5097c97e 100644 --- a/relay/constant/role/define.go +++ b/relay/constant/role/define.go @@ -1,5 +1,6 @@ package role const ( + System = "system" Assistant = "assistant" ) diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 83040662..247db452 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -5,21 +5,23 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" + "mime/multipart" "net/http" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" - "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" + "github.com/songquanpeng/one-api/relay/billing/ratio" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" @@ -27,6 +29,35 @@ import ( "github.com/songquanpeng/one-api/relay/relaymode" ) +type commonAudioRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` +} + +func countAudioTokens(c *gin.Context) (int, error) { + body, err := common.GetRequestBody(c) + if err != nil { + return 0, errors.WithStack(err) + } + + reqBody := new(commonAudioRequest) + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + if err = c.ShouldBind(reqBody); err != nil { + return 0, errors.WithStack(err) + } + + reqFp, err := reqBody.File.Open() + if err != nil { + return 0, errors.WithStack(err) + } + defer reqFp.Close() + + ctxMeta := meta.GetByContext(c) + + return helper.GetAudioTokens(c.Request.Context(), + reqFp, + ratio.GetAudioPromptTokensPerSecond(ctxMeta.ActualModelName)) +} + func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) @@ -36,7 +67,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus channelType := c.GetInt(ctxkey.Channel) channelId := c.GetInt(ctxkey.ChannelId) userId := c.GetInt(ctxkey.Id) - group := c.GetString(ctxkey.Group) + // group := c.GetString(ctxkey.Group) tokenName := c.GetString(ctxkey.TokenName) var ttsRequest openai.TextToSpeechRequest @@ -55,7 +86,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } modelRatio := billingratio.GetModelRatio(audioModel, channelType) - groupRatio := billingratio.GetGroupRatio(group) + // groupRatio := billingratio.GetGroupRatio(group) + groupRatio := c.GetFloat64(ctxkey.ChannelRatio) ratio := modelRatio * groupRatio var quota int64 var preConsumedQuota int64 @@ -63,9 +95,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus case relaymode.AudioSpeech: preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota + case relaymode.AudioTranscription, + relaymode.AudioTranslation: + audioTokens, err := countAudioTokens(c) + if err != nil { + return openai.ErrorWrapper(err, "count_audio_tokens_failed", 
http.StatusInternalServerError) + } + + preConsumedQuota = int64(float64(audioTokens) * ratio) + quota = preConsumedQuota default: - preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) + return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError) } + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -110,16 +152,9 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus }() // map model name - modelMapping := c.GetString(ctxkey.ModelMapping) - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[audioModel] != "" { - audioModel = modelMap[audioModel] - } + modelMapping := c.GetStringMapString(ctxkey.ModelMapping) + if modelMapping != nil && modelMapping[audioModel] != "" { + audioModel = modelMapping[audioModel] } baseURL := channeltype.ChannelBaseURLs[channelType] @@ -146,7 +181,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) - responseFormat := c.DefaultPostForm("response_format", "json") + // responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -179,47 +214,53 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != relaymode.AudioSpeech { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } + // https://github.com/Laisky/one-api/pull/21 + // Commenting out the following code because Whisper's transcription + // only charges for the length of the input audio, not for the output. 
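As a quick sanity check on the numbers behind this change, here is a minimal, self-contained sketch of the transcription billing path that the commented-out block replaces: duration-based prompt tokens times the combined ratio, with nothing added for the output. The 90-second duration, the model ratio of 15 and the group ratio of 1 are illustrative assumptions only; in the real code the token count comes from helper.GetAudioTokens and the ratios from GetModelRatio and the channel ratio on the request context.

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Assumptions for illustration: a 90-second upload to whisper-1,
	// model ratio 15 and group ratio 1. Real values come from
	// GetModelRatio / the channel ratio stored on the gin context.
	const (
		durationSeconds = 90.0
		tokensPerSecond = 1000.0 / 20.0 // mirrors AudioPromptTokensPerSecond["whisper-1"]
		modelRatio      = 15.0
		groupRatio      = 1.0
	)

	// countAudioTokens: audio length converted to prompt tokens.
	audioTokens := int(math.Ceil(durationSeconds * tokensPerSecond)) // 4500

	// RelayAudioHelper: quota is pre-consumed from the prompt side only.
	quota := int64(float64(audioTokens) * modelRatio * groupRatio) // 67500

	// AudioCompletionRatio["whisper-1"] is 0, so the transcript that comes
	// back adds nothing on top of the pre-consumed amount.
	fmt.Println("audio tokens:", audioTokens, "quota:", quota)
}
```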
+ // ------------------------------------- + // if relayMode != relaymode.AudioSpeech { + // responseBody, err := io.ReadAll(resp.Body) + // if err != nil { + // return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + // } + // err = resp.Body.Close() + // if err != nil { + // return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + // } - var openAIErr openai.SlimTextResponse - if err = json.Unmarshal(responseBody, &openAIErr); err == nil { - if openAIErr.Error.Message != "" { - return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) - } - } + // var openAIErr openai.SlimTextResponse + // if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + // if openAIErr.Error.Message != "" { + // return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + // } + // } + + // var text string + // switch responseFormat { + // case "json": + // text, err = getTextFromJSON(responseBody) + // case "text": + // text, err = getTextFromText(responseBody) + // case "srt": + // text, err = getTextFromSRT(responseBody) + // case "verbose_json": + // text, err = getTextFromVerboseJSON(responseBody) + // case "vtt": + // text, err = getTextFromVTT(responseBody) + // default: + // return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + // } + // if err != nil { + // return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + // } + // quota = int64(openai.CountTokenText(text, audioModel)) + // resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // } - var text string - switch responseFormat { - case "json": - text, err = getTextFromJSON(responseBody) - case "text": - text, err = getTextFromText(responseBody) - case "srt": - text, err = getTextFromSRT(responseBody) - case "verbose_json": - text, err = getTextFromVerboseJSON(responseBody) - case "vtt": - text, err = getTextFromVTT(responseBody) - default: - return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) - } - if err != nil { - return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) - } - quota = int64(openai.CountTokenText(text, audioModel)) - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } if resp.StatusCode != http.StatusOK { return RelayErrorHandler(resp) } + succeed = true quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { @@ -249,8 +290,9 @@ func getTextFromVTT(body []byte) (string, error) { func getTextFromVerboseJSON(body []byte) (string, error) { var whisperResponse openai.WhisperVerboseJSONResponse if err := json.Unmarshal(body, &whisperResponse); err != nil { - return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + return "", errors.Wrap(err, "unmarshal_response_body_failed") } + return whisperResponse.Text, nil } @@ -282,7 +324,7 @@ func getTextFromText(body []byte) (string, error) { func getTextFromJSON(body []byte) (string, error) { var whisperResponse openai.WhisperJSONResponse if err := json.Unmarshal(body, &whisperResponse); err != nil { - return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + return "", 
errors.Wrap(err, "unmarshal_response_body_failed") } return whisperResponse.Text, nil } diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 87d22f13..55fcc11e 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -2,13 +2,13 @@ package controller import ( "context" - "errors" "fmt" "math" "net/http" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" @@ -16,6 +16,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/constant/role" "github.com/songquanpeng/one-api/relay/controller/validator" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -41,10 +42,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } -func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { +func getPromptTokens(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case relaymode.ChatCompletions: - return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) + return openai.CountTokenMessages(ctx, textRequest.Messages, textRequest.Model) case relaymode.Completions: return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) case relaymode.Moderations: @@ -90,12 +91,12 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) (quota int64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } - var quota int64 + completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens @@ -118,10 +119,16 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) + var extraLog string + if systemPromptReset { + extraLog = " (注意系统提示词已被重置)" + } + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f%s", modelRatio, groupRatio, completionRatio, extraLog) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota) + + return quota } func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { @@ -142,15 +149,41 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { } return true } - if resp.StatusCode != http.StatusOK { + if resp.StatusCode != http.StatusOK && + // replicate return 201 to create 
a task + resp.StatusCode != http.StatusCreated { return true } if meta.ChannelType == channeltype.DeepL { // skip stream check for deepl return false } - if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + + if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") && + // Even if stream mode is enabled, replicate will first return a task info in JSON format, + // requiring the client to request the stream endpoint in the task info + meta.ChannelType != channeltype.Replicate { return true } return false } + +func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) { + if prompt == "" { + return false + } + if len(request.Messages) == 0 { + return false + } + if request.Messages[0].Role == role.System { + request.Messages[0].Content = prompt + logger.Infof(ctx, "rewrite system prompt") + return true + } + request.Messages = append([]relaymodel.Message{{ + Role: role.System, + Content: prompt, + }}, request.Messages...) + logger.Infof(ctx, "add system prompt") + return true +} diff --git a/relay/controller/image.go b/relay/controller/image.go index 1e06e858..613fde8b 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -4,29 +4,31 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net/http" + "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" - "github.com/songquanpeng/one-api/relay/meta" + metalib "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { +func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { - return nil, err + return nil, errors.WithStack(err) } if imageRequest.N == 0 { imageRequest.N = 1 @@ -65,7 +67,7 @@ func getImageSizeRatio(model string, size string) float64 { return 1 } -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { +func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) *relaymodel.ErrorWithStatusCode { // check prompt length if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) @@ -104,7 +106,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := meta.GetByContext(c) + meta := metalib.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) @@ -116,6 +118,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus meta.OriginModelName = imageRequest.Model imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, 
meta.ModelMapping) meta.ActualModelName = imageRequest.Model + metalib.Set2Context(c, meta) // model validation bizErr := validateImageRequest(imageRequest, meta) @@ -134,7 +137,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus c.Set("response_format", imageRequest.ResponseFormat) var requestBody io.Reader - if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body + if strings.ToLower(c.GetString(ctxkey.ContentType)) == "application/json" && + isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) @@ -150,13 +154,23 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } adaptor.Init(meta) + // these adaptors need to convert the request switch meta.ChannelType { - case channeltype.Ali: - fallthrough - case channeltype.Baidu: - fallthrough - case channeltype.Zhipu: - finalRequest, err := adaptor.ConvertImageRequest(imageRequest) + case channeltype.Zhipu, + channeltype.Ali, + channeltype.VertextAI, + channeltype.Baidu: + finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + case channeltype.Replicate: + finalRequest, err := replicate.ConvertImageRequest(c, imageRequest) if err != nil { return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) } @@ -168,11 +182,20 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) - groupRatio := billingratio.GetGroupRatio(meta.Group) + // groupRatio := billingratio.GetGroupRatio(meta.Group) + groupRatio := c.GetFloat64(ctxkey.ChannelRatio) + ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + var quota int64 + switch meta.ChannelType { + case channeltype.Replicate: + // replicate always return 1 image + quota = int64(ratio * imageCostRatio * 1000) + default: + quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + } if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -186,7 +209,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } defer func(ctx context.Context) { - if resp != nil && resp.StatusCode != http.StatusOK { + if resp != nil && + resp.StatusCode != http.StatusCreated && // replicate returns 201 + resp.StatusCode != http.StatusOK { return } @@ -198,13 +223,23 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } - if quota != 0 { + if quota >= 0 { tokenName := c.GetString(ctxkey.TokenName) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, 
quota) channelId := c.GetInt(ctxkey.ChannelId) model.UpdateChannelUsedQuota(channelId, quota) + + // also update user request cost + docu := model.NewUserRequestCost( + c.GetInt(ctxkey.Id), + c.GetString(ctxkey.RequestId), + quota, + ) + if err = docu.Insert(); err != nil { + logger.Errorf(c, "insert user request cost failed: %+v", err) + } } }(c.Request.Context()) diff --git a/relay/controller/text.go b/relay/controller/text.go index 52ee9949..0dc06019 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -2,13 +2,18 @@ package controller import ( "bytes" + "context" "encoding/json" - "fmt" "io" "net/http" + "time" "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -17,10 +22,10 @@ import ( billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" - "github.com/songquanpeng/one-api/relay/model" + relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { +func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) // get & validate textRequest @@ -35,12 +40,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { meta.OriginModelName = textRequest.Model textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model + // set system prompt if not empty + systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt) // get model ratio & group ratio modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) - groupRatio := billingratio.GetGroupRatio(meta.Group) + // groupRatio := billingratio.GetGroupRatio(meta.Group) + groupRatio := c.GetFloat64(ctxkey.ChannelRatio) + ratio := modelRatio * groupRatio // pre-consume quota - promptTokens := getPromptTokens(textRequest, meta.Mode) + promptTokens := getPromptTokens(c.Request.Context(), textRequest, meta.Mode) meta.PromptTokens = promptTokens preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta) if bizErr != nil { @@ -50,7 +59,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { adaptor := relay.GetAdaptor(meta.APIType) if adaptor == nil { - return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + return openai.ErrorWrapper(errors.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) } adaptor.Init(meta) @@ -60,6 +69,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) } + // for debug + requestBodyBytes, _ := io.ReadAll(requestBody) + requestBody = bytes.NewBuffer(requestBodyBytes) + // do request resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { @@ -78,14 +91,38 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return respErr } + // post-consume quota - go postConsumeQuota(ctx, usage, meta, 
textRequest, ratio, preConsumedQuota, modelRatio, groupRatio) + quotaId := c.GetInt(ctxkey.Id) + requestId := c.GetString(ctxkey.RequestId) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + quota := postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) + + // also update user request cost + if quota != 0 { + docu := model.NewUserRequestCost( + quotaId, + requestId, + quota, + ) + if err = docu.Insert(); err != nil { + logger.Errorf(ctx, "insert user request cost failed: %+v", err) + } + } + }() + return nil } -func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { - if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan { - // no need to convert request for openai +func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) { + if !config.EnforceIncludeUsage && + meta.APIType == apitype.OpenAI && + meta.OriginModelName == meta.ActualModelName && + meta.ChannelType != channeltype.OpenAI && // openai also need to convert request + meta.ChannelType != channeltype.Baichuan { return c.Request.Body, nil } diff --git a/relay/controller/validator/validation.go b/relay/controller/validator/validation.go index 8ff520b8..3d5d755f 100644 --- a/relay/controller/validator/validation.go +++ b/relay/controller/validator/validation.go @@ -1,10 +1,11 @@ package validator import ( - "errors" + "math" + + "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "math" ) func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error { diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index b1761e9a..b4bcf687 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -1,12 +1,13 @@ package meta import ( + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/relaymode" - "strings" ) type Meta struct { @@ -30,9 +31,15 @@ type Meta struct { ActualModelName string RequestURLPath string PromptTokens int // only for DoResponse + ChannelRatio float64 + SystemPrompt string } func GetByContext(c *gin.Context) *Meta { + if v, ok := c.Get(ctxkey.Meta); ok { + return v.(*Meta) + } + meta := Meta{ Mode: relaymode.GetByPath(c.Request.URL.Path), ChannelType: c.GetInt(ctxkey.Channel), @@ -46,6 +53,8 @@ func GetByContext(c *gin.Context) *Meta { BaseURL: c.GetString(ctxkey.BaseURL), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), RequestURLPath: c.Request.URL.String(), + ChannelRatio: c.GetFloat64(ctxkey.ChannelRatio), // add by Laisky + SystemPrompt: c.GetString(ctxkey.SystemPrompt), } cfg, ok := c.Get(ctxkey.Config) if ok { @@ -55,5 +64,11 @@ func GetByContext(c *gin.Context) *Meta { meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] } meta.APIType = channeltype.ToAPIType(meta.ChannelType) + + Set2Context(c, &meta) return &meta } + +func Set2Context(c *gin.Context, meta *Meta) { + c.Set(ctxkey.Meta, meta) +} diff --git a/relay/model/constant.go b/relay/model/constant.go index f6cf1924..c9d6d645 100644 --- a/relay/model/constant.go +++ 
b/relay/model/constant.go @@ -1,6 +1,7 @@ package model const ( - ContentTypeText = "text" - ContentTypeImageURL = "image_url" + ContentTypeText = "text" + ContentTypeImageURL = "image_url" + ContentTypeInputAudio = "input_audio" ) diff --git a/relay/model/general.go b/relay/model/general.go index c34c1c2d..5354694c 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -1,35 +1,71 @@ package model type ResponseFormat struct { - Type string `json:"type,omitempty"` + Type string `json:"type,omitempty"` + JsonSchema *JSONSchema `json:"json_schema,omitempty"` +} + +type JSONSchema struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Schema map[string]interface{} `json:"schema,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +type Audio struct { + Voice string `json:"voice,omitempty"` + Format string `json:"format,omitempty"` +} + +type StreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` } type GeneralOpenAIRequest struct { - Messages []Message `json:"messages,omitempty"` - Model string `json:"model,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` + // https://platform.openai.com/docs/api-reference/chat/create + Messages []Message `json:"messages,omitempty"` + Model string `json:"model,omitempty"` + Store *bool `json:"store,omitempty"` + Metadata any `json:"metadata,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias any `json:"logit_bias,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + // Modalities currently the model only programmatically allows modalities = [“text”, “audio”] + Modalities []string `json:"modalities,omitempty"` + Prediction any `json:"prediction,omitempty"` + Audio *Audio `json:"audio,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` Stop any `json:"stop,omitempty"` Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` + User string `json:"user,omitempty"` FunctionCall any `json:"function_call,omitempty"` Functions any `json:"functions,omitempty"` - User string `json:"user,omitempty"` - Prompt any `json:"prompt,omitempty"` - Input any `json:"input,omitempty"` - EncodingFormat string `json:"encoding_format,omitempty"` - Dimensions int `json:"dimensions,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` + // https://platform.openai.com/docs/api-reference/embeddings/create + Input any `json:"input,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + // 
https://platform.openai.com/docs/api-reference/images/create + Prompt any `json:"prompt,omitempty"` + Quality *string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + Style *string `json:"style,omitempty"` + // Others + Instruction string `json:"instruction,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/model/image.go b/relay/model/image.go index bab84256..ec3e7691 100644 --- a/relay/model/image.go +++ b/relay/model/image.go @@ -1,12 +1,12 @@ package model type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model" form:"model"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` + N int `json:"n,omitempty" form:"n"` + Size string `json:"size,omitempty" form:"size"` + Quality string `json:"quality,omitempty" form:"quality"` + ResponseFormat string `json:"response_format,omitempty" form:"response_format"` + Style string `json:"style,omitempty" form:"style"` + User string `json:"user,omitempty" form:"user"` } diff --git a/relay/model/message.go b/relay/model/message.go index b908f989..48ddb3ad 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,11 +1,26 @@ package model +import ( + "context" + + "github.com/songquanpeng/one-api/common/logger" +) + type Message struct { - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - Name *string `json:"name,omitempty"` - ToolCalls []Tool `json:"tool_calls,omitempty"` - ToolCallId string `json:"tool_call_id,omitempty"` + Role string `json:"role,omitempty"` + // Content is a string or a list of objects + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` + Audio *messageAudio `json:"audio,omitempty"` +} + +type messageAudio struct { + Id string `json:"id"` + Data string `json:"data,omitempty"` + ExpiredAt int `json:"expired_at,omitempty"` + Transcript string `json:"transcript,omitempty"` } func (m Message) IsStringContent() bool { @@ -26,6 +41,7 @@ func (m Message) StringContent() string { if !ok { continue } + if contentMap["type"] == ContentTypeText { if subStr, ok := contentMap["text"].(string); ok { contentStr += subStr @@ -34,6 +50,7 @@ func (m Message) StringContent() string { } return contentStr } + return "" } @@ -47,6 +64,7 @@ func (m Message) ParseContent() []MessageContent { }) return contentList } + anyList, ok := m.Content.([]any) if ok { for _, contentItem := range anyList { @@ -71,8 +89,21 @@ func (m Message) ParseContent() []MessageContent { }, }) } + case ContentTypeInputAudio: + if subObj, ok := contentMap["input_audio"].(map[string]any); ok { + contentList = append(contentList, MessageContent{ + Type: ContentTypeInputAudio, + InputAudio: &InputAudio{ + Data: subObj["data"].(string), + Format: subObj["format"].(string), + }, + }) + } + default: + logger.Warnf(context.TODO(), "unknown content type: %s", contentMap["type"]) } } + return contentList } return nil @@ -84,7 +115,18 @@ type ImageURL struct { } type MessageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageURL *ImageURL 
`json:"image_url,omitempty"` + // Type should be one of the following: text/input_audio + Type string `json:"type,omitempty"` + Text string `json:"text"` + ImageURL *ImageURL `json:"image_url,omitempty"` + InputAudio *InputAudio `json:"input_audio,omitempty"` +} + +type InputAudio struct { + // Data is the base64 encoded audio data + Data string `json:"data" binding:"required"` + // Format is the audio format, should be one of the + // following: mp3/mp4/mpeg/mpga/m4a/wav/webm/pcm16. + // When stream=true, format should be pcm16 + Format string `json:"format"` } diff --git a/relay/model/misc.go b/relay/model/misc.go index 163bc398..62c3fe6f 100644 --- a/relay/model/misc.go +++ b/relay/model/misc.go @@ -4,6 +4,12 @@ type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` + // PromptTokensDetails may be empty for some models + PromptTokensDetails usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"` + // CompletionTokensDetails may be empty for some models + CompletionTokensDetails usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"` + ServiceTier string `gorm:"-" json:"service_tier,omitempty"` + SystemFingerprint string `gorm:"-" json:"system_fingerprint,omitempty"` } type Error struct { @@ -17,3 +23,20 @@ type ErrorWithStatusCode struct { Error StatusCode int `json:"status_code"` } + +type usagePromptTokensDetails struct { + CachedTokens int `json:"cached_tokens"` + AudioTokens int `json:"audio_tokens"` + // TextTokens could be zero for pure text chats + TextTokens int `json:"text_tokens"` + ImageTokens int `json:"image_tokens"` +} + +type usageCompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` + // TextTokens could be zero for pure text chats + TextTokens int `json:"text_tokens"` +} diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go index aa771205..12acb940 100644 --- a/relay/relaymode/define.go +++ b/relay/relaymode/define.go @@ -13,4 +13,5 @@ const ( AudioTranslation // Proxy is a special relay mode for proxying requests to custom upstream Proxy + ImagesEdits ) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go index 2cde5b85..35a0535e 100644 --- a/relay/relaymode/helper.go +++ b/relay/relaymode/helper.go @@ -24,8 +24,11 @@ func GetByPath(path string) int { relayMode = AudioTranscription } else if strings.HasPrefix(path, "/v1/audio/translations") { relayMode = AudioTranslation + } else if strings.HasPrefix(path, "/v1/images/edits") { + relayMode = ImagesEdits } else if strings.HasPrefix(path, "/v1/oneapi/proxy") { relayMode = Proxy } + return relayMode } diff --git a/router/api.go b/router/api.go index d2ada4eb..d5a6bf4f 100644 --- a/router/api.go +++ b/router/api.go @@ -21,8 +21,10 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/home_page_content", controller.GetHomePageContent) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) + apiRouter.GET("/user/get-by-token", middleware.TokenAuth(), controller.GetSelfByToken) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) 
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth) apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) @@ -92,6 +94,11 @@ func SetApiRouter(router *gin.Engine) { tokenRoute.POST("/", controller.AddToken) tokenRoute.PUT("/", controller.UpdateToken) tokenRoute.DELETE("/:id", controller.DeleteToken) + apiRouter.POST("/token/consume", middleware.TokenAuth(), controller.ConsumeToken) + } + costRoute := apiRouter.Group("/cost") + { + costRoute.GET("/request/:request_id", controller.GetRequestCost) } redemptionRoute := apiRouter.Group("/redemption") redemptionRoute.Use(middleware.AdminAuth()) diff --git a/router/relay.go b/router/relay.go index 094ea5fb..554a64f4 100644 --- a/router/relay.go +++ b/router/relay.go @@ -9,6 +9,7 @@ import ( func SetRelayRouter(router *gin.Engine) { router.Use(middleware.CORS()) + router.Use(middleware.GzipDecodeMiddleware()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) @@ -24,7 +25,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/images/generations", controller.Relay) - relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/edits", controller.Relay) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay) diff --git a/router/web.go b/router/web.go index 3c9b4643..ebfc2ae1 100644 --- a/router/web.go +++ b/router/web.go @@ -3,6 +3,9 @@ package router import ( "embed" "fmt" + "net/http" + "strings" + "github.com/gin-contrib/gzip" "github.com/gin-contrib/static" "github.com/gin-gonic/gin" @@ -10,8 +13,6 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/middleware" - "net/http" - "strings" ) func SetWebRouter(router *gin.Engine, buildFS embed.FS) { diff --git a/web/air/src/components/TokensTable.js b/web/air/src/components/TokensTable.js index 0853ddfb..48836c85 100644 --- a/web/air/src/components/TokensTable.js +++ b/web/air/src/components/TokensTable.js @@ -11,12 +11,14 @@ import EditToken from '../pages/Token/EditToken'; const COPY_OPTIONS = [ { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, - { key: 'opencat', text: 'OpenCat', value: 'opencat' } + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, + { key: 'lobechat', text: 'LobeChat', value: 'lobechat' }, ]; const OPEN_LINK_OPTIONS = [ { key: 'ama', text: 'ChatGPT Web & Midjourney', value: 'ama' }, - { key: 'opencat', text: 'OpenCat', value: 'opencat' } + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, + { key: 'lobechat', text: 'LobeChat', value: 'lobechat' } ]; function renderTimestamp(timestamp) { @@ -60,7 +62,12 @@ const TokensTable = () => { onOpenLink('next-mj'); } }, - { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' } + { node: 'item', key: 'opencat', name: 'OpenCat', value: 'opencat' }, + { + node: 'item', key: 'lobechat', name: 
'LobeChat', onClick: () => { + onOpenLink('lobechat'); + } + } ]; const columns = [ @@ -177,6 +184,11 @@ const TokensTable = () => { node: 'item', key: 'opencat', name: 'OpenCat', onClick: () => { onOpenLink('opencat', record.key); } + }, + { + node: 'item', key: 'lobechat', name: 'LobeChat', onClick: () => { + onOpenLink('lobechat'); + } } ] } @@ -382,6 +394,9 @@ const TokensTable = () => { case 'next-mj': url = mjLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; break; + case 'lobechat': + url = chatLink + `/?settings={"keyVaults":{"openai":{"apiKey":"sk-${key}","baseURL":"${serverAddress}/v1"}}}`; + break; default: if (!chatLink) { showError('管理员未设置聊天链接'); diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js index 04fe94f1..e7b25399 100644 --- a/web/air/src/constants/channel.constants.js +++ b/web/air/src/constants/channel.constants.js @@ -30,6 +30,8 @@ export const CHANNEL_OPTIONS = [ { key: 42, text: 'VertexAI', value: 42, color: 'blue' }, { key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, + { key: 45, text: 'xAI', value: 45, color: 'blue' }, + { key: 46, text: 'Replicate', value: 46, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/air/src/pages/Channel/EditChannel.js b/web/air/src/pages/Channel/EditChannel.js index 73fd2da2..4a810830 100644 --- a/web/air/src/pages/Channel/EditChannel.js +++ b/web/air/src/pages/Channel/EditChannel.js @@ -43,6 +43,7 @@ const EditChannel = (props) => { base_url: '', other: '', model_mapping: '', + system_prompt: '', models: [], auto_ban: 1, groups: ['default'] @@ -63,7 +64,7 @@ const EditChannel = (props) => { let localModels = []; switch (value) { case 14: - localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-sonnet-20240620"]; + localModels = ["claude-instant-1.2", "claude-2", "claude-2.0", "claude-2.1", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-3-5-haiku-20241022", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022"]; break; case 11: localModels = ['PaLM-2']; @@ -78,7 +79,7 @@ const EditChannel = (props) => { localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite']; break; case 18: - localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.5', 'SparkDesk-v4.0']; + localModels = ['SparkDesk', 'SparkDesk-v1.1', 'SparkDesk-v2.1', 'SparkDesk-v3.1', 'SparkDesk-v3.1-128K', 'SparkDesk-v3.5', 'SparkDesk-v3.5-32K', 'SparkDesk-v4.0']; break; case 19: localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; @@ -304,163 +305,163 @@ const EditChannel = (props) => { width={isMobile() ? '100%' : 600} > -
+
类型:
{ - handleInputChange('base_url', value) - }} - value={inputs.base_url} - autoComplete='new-password' - /> -
- 默认 API 版本: -
- { - handleInputChange('other', value) - }} - value={inputs.other} - autoComplete='new-password' - /> - - ) + inputs.type === 3 && ( + <> +
+ + 注意,模型部署名称必须和模型名称保持一致,因为 One API 会把请求体中的 + model + 参数替换为你的部署名称(模型名称中的点会被剔除),图片演示。 + + }> + +
+
+ AZURE_OPENAI_ENDPOINT: +
+ { + handleInputChange('base_url', value) + }} + value={inputs.base_url} + autoComplete='new-password' + /> +
+ 默认 API 版本: +
+ { + handleInputChange('other', value) + }} + value={inputs.other} + autoComplete='new-password' + /> + + ) } { - inputs.type === 8 && ( - <> -
- Base URL: -
- { - handleInputChange('base_url', value) - }} - value={inputs.base_url} - autoComplete='new-password' - /> - - ) + inputs.type === 8 && ( + <> +
+ Base URL: +
+ { + handleInputChange('base_url', value) + }} + value={inputs.base_url} + autoComplete='new-password' + /> + + ) } -
+
名称:
{ - handleInputChange('name', value) - }} - value={inputs.name} - autoComplete='new-password' + required + name='name' + placeholder={'请为渠道命名'} + onChange={value => { + handleInputChange('name', value) + }} + value={inputs.name} + autoComplete='new-password' /> -
+
分组:
{ - handleInputChange('other', value) - }} - value={inputs.other} - autoComplete='new-password' - /> - - ) + inputs.type === 18 && ( + <> +
+ 模型版本: +
+ { + handleInputChange('other', value) + }} + value={inputs.other} + autoComplete='new-password' + /> + + ) } { - inputs.type === 21 && ( - <> -
- 知识库 ID: -
- { - handleInputChange('other', value) - }} - value={inputs.other} - autoComplete='new-password' - /> - - ) + inputs.type === 21 && ( + <> +
+ 知识库 ID: +
+ { + handleInputChange('other', value) + }} + value={inputs.other} + autoComplete='new-password' + /> + + ) } -
+
模型:
填入 - } - placeholder='输入自定义模型名称' - value={customModel} - onChange={(value) => { - setCustomModel(value.trim()); - }} + addonAfter={ + + } + placeholder='输入自定义模型名称' + value={customModel} + onChange={(value) => { + setCustomModel(value.trim()); + }} />
-
+
模型重定向: