mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-18 11:33:42 +08:00
Compare commits
80 Commits
v0.3.1.0-a
...
v0.2.10-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c0ab39e446 | ||
|
|
e4c01cb9ae | ||
|
|
d43a65bc52 | ||
|
|
d62dd4d9a2 | ||
|
|
f113e1874e | ||
|
|
c6d5245c5c | ||
|
|
d1ea2d2d0a | ||
|
|
e0f780185a | ||
|
|
c47e1dc6fe | ||
|
|
27b8495698 | ||
|
|
333849429b | ||
|
|
9b14c5da63 | ||
|
|
a9e3555cac | ||
|
|
8d83630d85 | ||
|
|
a60f209c85 | ||
|
|
208bc5e794 | ||
|
|
400b2b0ed0 | ||
|
|
27b034674d | ||
|
|
fefe5913e9 | ||
|
|
142d5fb209 | ||
|
|
f6ccd402e2 | ||
|
|
1c371300ab | ||
|
|
918690701d | ||
|
|
f0008e95fa | ||
|
|
7a249b206d | ||
|
|
0cc7f5cca6 | ||
|
|
ed86ec8b59 | ||
|
|
895ee09b33 | ||
|
|
61006bed9e | ||
|
|
58591b8c2a | ||
|
|
69b9950aa0 | ||
|
|
bcf1b160d1 | ||
|
|
1e0e40aa69 | ||
|
|
74392efbed | ||
|
|
cc020d6a40 | ||
|
|
daa1741aed | ||
|
|
a25bcaa58f | ||
|
|
d34b601dae | ||
|
|
9932962320 | ||
|
|
00d6cda9ed | ||
|
|
5ffb520363 | ||
|
|
e0f80cdb8f | ||
|
|
7a7a923504 | ||
|
|
4f6c171a08 | ||
|
|
310f8c247e | ||
|
|
811019bf5c | ||
|
|
0ed6600437 | ||
|
|
1fe7f14d57 | ||
|
|
a7bafec1bf | ||
|
|
ed951b3974 | ||
|
|
c74e43b8fd | ||
|
|
03bd9b0cc4 | ||
|
|
149902bd8a | ||
|
|
cd3ed22045 | ||
|
|
2329d387ca | ||
|
|
7d18a8e2a9 | ||
|
|
80af3718d0 | ||
|
|
77ea6bec46 | ||
|
|
c0ab8ae953 | ||
|
|
923c2dee32 | ||
|
|
ea17a46d8e | ||
|
|
bfe9e5d25a | ||
|
|
831ff47254 | ||
|
|
11eaba6b5d | ||
|
|
c2e4ec25c8 | ||
|
|
8537f10412 | ||
|
|
71d60eeef7 | ||
|
|
247ae0988f | ||
|
|
0907fa6994 | ||
|
|
9855343aa8 | ||
|
|
bd50fde268 | ||
|
|
4267de5642 | ||
|
|
8b55116563 | ||
|
|
f35e63e3f3 | ||
|
|
17c409de23 | ||
|
|
e4753e7411 | ||
|
|
9adefa80b9 | ||
|
|
4ce2381182 | ||
|
|
62afc21ea5 | ||
|
|
7ddb7c586d |
12
.github/FUNDING.yml
vendored
12
.github/FUNDING.yml
vendored
@@ -1,12 +0,0 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
6
.github/ISSUE_TEMPLATE/config.yml
vendored
6
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,5 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 项目群聊
|
||||
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
|
||||
about: QQ 群:629454374
|
||||
- name: 交流社区
|
||||
url: https://linux.do
|
||||
about: 项目交流社区
|
||||
|
||||
5
.github/workflows/docker-image-amd64.yml
vendored
5
.github/workflows/docker-image-amd64.yml
vendored
@@ -1,9 +1,6 @@
|
||||
name: Publish Docker image (amd64)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
@@ -42,7 +39,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
2
.github/workflows/docker-image-arm64.yml
vendored
2
.github/workflows/docker-image-arm64.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
pengzhile/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
FROM node:16 as builder
|
||||
FROM node:20 as builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
RUN npm install
|
||||
RUN yarn install
|
||||
COPY ./web .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
|
||||
|
||||
FROM golang AS builder2
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ all: build-frontend start-backend
|
||||
|
||||
build-frontend:
|
||||
@echo "Building frontend..."
|
||||
@cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
@cd $(FRONTEND_DIR) && yarn install --network-timeout 1000000 && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) yarn build
|
||||
|
||||
start-backend:
|
||||
@echo "Starting backend dev server..."
|
||||
@@ -48,11 +48,10 @@
|
||||
4. Telegram Bot 名称是bot username 去掉@后的字符串
|
||||
13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md)
|
||||
14. 支持Rerank模型,目前仅兼容Cohere和Jina,可接入Dify,[对接文档](Rerank.md)
|
||||
15. **[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API,支持Azure渠道。
|
||||
|
||||
## 模型支持
|
||||
此版本额外支持以下模型:
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*)
|
||||
1. 第三方模型 **gps** (gpt-4-gizmo-*, g-*)
|
||||
2. 智谱glm-4v,glm-4v识图
|
||||
3. Anthropic Claude 3
|
||||
4. [Ollama](https://github.com/ollama/ollama?tab=readme-ov-file),添加渠道时,密钥可以随便填写,默认的请求地址是[http://localhost:11434](http://localhost:11434),如果需要修改请在渠道中修改
|
||||
@@ -64,7 +63,7 @@
|
||||
10. Dify
|
||||
11. Vertex AI,目前兼容Claude,Gemini,Llama3.1
|
||||
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
您可以在渠道中添加自定义模型gpt-4-gizmo-*或g-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
|
||||
|
||||
## 比原版One API多出的配置
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`。
|
||||
@@ -134,9 +133,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||

|
||||

|
||||
|
||||
## 交流群
|
||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="200">
|
||||
|
||||
## 相关项目
|
||||
- [One API](https://github.com/songquanpeng/one-api):原版项目
|
||||
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy):Midjourney接口支持
|
||||
|
||||
@@ -9,9 +9,20 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Pay Settings
|
||||
|
||||
var StripeApiSecret = ""
|
||||
var StripeWebhookSecret = ""
|
||||
var StripePriceId = ""
|
||||
var PaymentEnabled = false
|
||||
var StripeUnitPrice = 8.0
|
||||
var MinTopUp = 5
|
||||
|
||||
var StartTime = time.Now().Unix() // unit: second
|
||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||
var SystemName = "New API"
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var OutProxyUrl = ""
|
||||
var Footer = ""
|
||||
var Logo = ""
|
||||
var TopUpLink = ""
|
||||
@@ -41,10 +52,12 @@ var PasswordLoginEnabled = true
|
||||
var PasswordRegisterEnabled = true
|
||||
var EmailVerificationEnabled = false
|
||||
var GitHubOAuthEnabled = false
|
||||
var LinuxDoOAuthEnabled = false
|
||||
var WeChatAuthEnabled = false
|
||||
var TelegramOAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
var UserSelfDeletionEnabled = false
|
||||
|
||||
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
|
||||
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
|
||||
@@ -75,6 +88,10 @@ var SMTPToken = ""
|
||||
var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
|
||||
var LinuxDoClientId = ""
|
||||
var LinuxDoClientSecret = ""
|
||||
var LinuxDoMinLevel = 0
|
||||
|
||||
var WeChatServerAddress = ""
|
||||
var WeChatServerToken = ""
|
||||
var WeChatAccountQRCodeImageURL = ""
|
||||
@@ -183,6 +200,12 @@ const (
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
@@ -222,7 +245,6 @@ const (
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
@@ -271,5 +293,4 @@ var ChannelBaseURLs = []string{
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
}
|
||||
|
||||
84
common/hash.go
Normal file
84
common/hash.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Sha256Raw(data string) []byte {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(data))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1Raw(data []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(data))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1(data string) string {
|
||||
return hex.EncodeToString(Sha1Raw([]byte(data)))
|
||||
}
|
||||
|
||||
func HmacSha256Raw(message, key []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(message)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func HmacSha256(message, key string) string {
|
||||
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||
}
|
||||
|
||||
func RandomBytes(length int) []byte {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
b := make([]byte, length)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func RandomString(length int) string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomHex(length int) string {
|
||||
const chars = "abcdef0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomNumber(length int) string {
|
||||
const chars = "0123456789"
|
||||
result := make([]byte, length)
|
||||
randomBytes := RandomBytes(length)
|
||||
for i := 0; i < length; i++ {
|
||||
result[i] = chars[randomBytes[i]%byte(len(chars))]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomUUID() string {
|
||||
all := RandomHex(32)
|
||||
return all[:8] + "-" + all[8:12] + "-" + all[12:16] + "-" + all[16:20] + "-" + all[20:]
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -31,9 +30,24 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
resp, err := ProxiedHttpHead(url, OutProxyUrl)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
resp, err := DoImageRequest(url)
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
return
|
||||
}
|
||||
resp, err := ProxiedHttpGet(url, OutProxyUrl)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -52,9 +66,9 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
response, err := DoImageRequest(imageUrl)
|
||||
response, err := ProxiedHttpGet(imageUrl, OutProxyUrl)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
@@ -66,7 +80,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
|
||||
var readData []byte
|
||||
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
||||
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||
SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||
|
||||
// 从response.Body读取更多的数据直到达到当前的限制
|
||||
additionalData := make([]byte, limit-int64(len(readData)))
|
||||
@@ -92,11 +106,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
SysLog(err.Error())
|
||||
config, err = webp.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
common.SysLog(err.Error())
|
||||
SysLog(err.Error())
|
||||
}
|
||||
format = "webp"
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
@@ -100,12 +99,10 @@ func LogQuota(quota int) string {
|
||||
}
|
||||
}
|
||||
|
||||
// LogJson 仅供测试使用 only for test
|
||||
func LogJson(ctx context.Context, msg string, obj any) {
|
||||
jsonStr, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
func LogQuotaF(quota float64) string {
|
||||
if DisplayInCurrencyEnabled {
|
||||
return fmt.Sprintf("$%.6f 额度", quota/QuotaPerUnit)
|
||||
} else {
|
||||
return fmt.Sprintf("%d 点额度", int64(quota))
|
||||
}
|
||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
}
|
||||
|
||||
@@ -23,50 +23,49 @@ const (
|
||||
|
||||
var defaultModelRatio = map[string]float64{
|
||||
//"midjourney": 50,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"gpt-4o-gizmo-*": 2.5,
|
||||
"gpt-4-all": 15,
|
||||
"gpt-4o-all": 15,
|
||||
"gpt-4": 15,
|
||||
//"gpt-4-0314": 15, //deprecated
|
||||
"gpt-4-0613": 15,
|
||||
"gpt-4-32k": 30,
|
||||
//"gpt-4-32k-0314": 30, //deprecated
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o": 1.25, // $0.01 / 1K tokens
|
||||
"gpt-4o-audio-preview": 1.25, // $0.0015 / 1K tokens
|
||||
"gpt-4o-audio-preview-2024-10-01": 1.25, // $0.0015 / 1K tokens
|
||||
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
|
||||
"gpt-4o-2024-05-13": 2.5,
|
||||
"gpt-4o-realtime-preview": 2.5,
|
||||
"o1-preview": 7.5,
|
||||
"o1-preview-2024-09-12": 7.5,
|
||||
"o1-mini": 1.5,
|
||||
"o1-mini-2024-09-12": 1.5,
|
||||
"gpt-4o-mini": 0.075,
|
||||
"gpt-4o-mini-2024-07-18": 0.075,
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
//"gpt-3.5-turbo-0301": 0.75, //deprecated
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||
"gpt-3.5-turbo-0125": 0.25,
|
||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
||||
"davinci-002": 1, // $0.002 / 1K tokens
|
||||
"text-ada-001": 0.2,
|
||||
"text-babbage-001": 0.25,
|
||||
"text-curie-001": 1,
|
||||
//"text-davinci-002": 10,
|
||||
//"text-davinci-003": 10,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"g-*": 15,
|
||||
"gpt-4": 15,
|
||||
"gpt-4-0314": 15,
|
||||
"gpt-4-0613": 15,
|
||||
"gpt-4-32k": 30,
|
||||
"gpt-4-32k-0314": 30,
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
|
||||
"gpt-4o-mini-2024-07-18": 0.075,
|
||||
"chatgpt-4o-latest": 2.5, // $0.01 / 1K tokens
|
||||
"gpt-4o": 1.25, // $0.005 / 1K tokens
|
||||
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
|
||||
"gpt-4o-2024-08-06": 1.25, // $0.01 / 1K tokens
|
||||
"gpt-4o-2024-11-20": 1.25, // $0.01 / 1K tokens
|
||||
"o1-preview": 7.5,
|
||||
"o1-preview-2024-09-12": 7.5,
|
||||
"o1-mini": 0.55, // $0.0011 / 1K tokens
|
||||
"o1-mini-2024-09-12": 0.55,
|
||||
"o3-mini": 0.55,
|
||||
"o3-mini-2025-01-31": 0.55,
|
||||
"gpt-4-turbo": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
|
||||
"gpt-3.5-turbo-0301": 0.75,
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||
"gpt-3.5-turbo-0125": 0.25,
|
||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
||||
"davinci-002": 1, // $0.002 / 1K tokens
|
||||
"text-ada-001": 0.2,
|
||||
"text-babbage-001": 0.25,
|
||||
"text-curie-001": 1,
|
||||
"text-davinci-002": 10,
|
||||
"text-davinci-003": 10,
|
||||
"text-davinci-edit-001": 10,
|
||||
"code-davinci-edit-001": 10,
|
||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
||||
@@ -88,11 +87,10 @@ var defaultModelRatio = map[string]float64{
|
||||
"claude-2.0": 4, // $8 / 1M tokens
|
||||
"claude-2.1": 4, // $8 / 1M tokens
|
||||
"claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens
|
||||
"claude-3-5-haiku-20241022": 0.5, // $1 / 1M tokens
|
||||
"claude-3-5-sonnet-20240620": 1.5, // $3 / 1M tokens
|
||||
"claude-3-5-sonnet-20241022": 1.5, // $3 / 1M tokens
|
||||
"claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens
|
||||
"claude-3-5-sonnet-20240620": 1.5,
|
||||
"claude-3-5-sonnet-20241022": 1.5,
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"claude-3-opus-20240229": 7.5, // $15 / 1M tokens
|
||||
"ERNIE-4.0-8K": 0.120 * RMB,
|
||||
"ERNIE-3.5-8K": 0.012 * RMB,
|
||||
"ERNIE-3.5-8K-0205": 0.024 * RMB,
|
||||
@@ -186,8 +184,10 @@ var defaultModelRatio = map[string]float64{
|
||||
var defaultModelPrice = map[string]float64{
|
||||
"suno_music": 0.1,
|
||||
"suno_lyrics": 0.01,
|
||||
"dall-e-2": 0.02,
|
||||
"dall-e-3": 0.04,
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"g-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
@@ -217,9 +217,10 @@ var (
|
||||
|
||||
var CompletionRatio map[string]float64 = nil
|
||||
var defaultCompletionRatio = map[string]float64{
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"gpt-4o-gizmo-*": 3,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-4-gizmo-*": 2,
|
||||
"g-*": 2,
|
||||
"gpt-4-all": 2,
|
||||
"gpt-4o-all": 2,
|
||||
}
|
||||
|
||||
func GetModelPriceMap() map[string]float64 {
|
||||
@@ -252,9 +253,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
||||
GetModelPriceMap()
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
||||
name = "gpt-4o-gizmo-*"
|
||||
} else if strings.HasPrefix(name, "g-") {
|
||||
name = "g-*"
|
||||
}
|
||||
price, ok := modelPriceMap[name]
|
||||
if !ok {
|
||||
@@ -295,6 +295,8 @@ func GetModelRatio(name string) float64 {
|
||||
GetModelRatioMap()
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
} else if strings.HasPrefix(name, "g-") {
|
||||
name = "g-*"
|
||||
}
|
||||
ratio, ok := modelRatioMap[name]
|
||||
if !ok {
|
||||
@@ -335,46 +337,49 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
||||
func GetCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
||||
name = "gpt-4o-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
|
||||
if strings.HasPrefix(name, "gpt-4o") {
|
||||
if name == "gpt-4o-2024-05-13" {
|
||||
return 3
|
||||
}
|
||||
return 4
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") {
|
||||
return 3
|
||||
}
|
||||
return 2
|
||||
}
|
||||
if strings.HasPrefix(name, "o1-") {
|
||||
return 4
|
||||
}
|
||||
if name == "chatgpt-4o-latest" {
|
||||
return 3
|
||||
}
|
||||
if strings.Contains(name, "claude-instant-1") {
|
||||
return 3
|
||||
} else if strings.Contains(name, "claude-2") {
|
||||
return 3
|
||||
} else if strings.Contains(name, "claude-3") {
|
||||
return 5
|
||||
} else if strings.HasPrefix(name, "g-") {
|
||||
name = "g-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-3.5") {
|
||||
if name == "gpt-3.5-turbo" || strings.HasSuffix(name, "0125") {
|
||||
// https://openai.com/blog/new-embedding-models-and-api-updates
|
||||
// Updated GPT-3.5 Turbo model and lower pricing
|
||||
if strings.HasSuffix(name, "0125") {
|
||||
return 3
|
||||
}
|
||||
if strings.HasSuffix(name, "1106") {
|
||||
return 2
|
||||
}
|
||||
if name == "gpt-3.5-turbo" {
|
||||
return 3
|
||||
}
|
||||
|
||||
return 4.0 / 3.0
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") && name != "gpt-4-all" && name != "gpt-4-gizmo-*" {
|
||||
if strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4-turbo") || "gpt-4o-2024-05-13" == name {
|
||||
return 3
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "gpt-4o") {
|
||||
return 4
|
||||
}
|
||||
|
||||
return 2
|
||||
}
|
||||
if "o1" == name || strings.HasPrefix(name, "o1-") {
|
||||
return 4
|
||||
}
|
||||
if "o3" == name || strings.HasPrefix(name, "o3-") {
|
||||
return 4
|
||||
}
|
||||
if name == "chatgpt-4o-latest" {
|
||||
return 3
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-instant-1") {
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-2") {
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-3") {
|
||||
return 5
|
||||
}
|
||||
if strings.HasPrefix(name, "mistral-") {
|
||||
return 3
|
||||
}
|
||||
@@ -421,34 +426,6 @@ func GetCompletionRatio(name string) float64 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func GetAudioRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
return 20
|
||||
}
|
||||
return 20
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
return 10
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
//func GetAudioPricePerMinute(name string) float64 {
|
||||
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
// return 0.06
|
||||
// }
|
||||
// return 0.06
|
||||
//}
|
||||
//
|
||||
//func GetAudioCompletionPricePerMinute(name string) float64 {
|
||||
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
||||
// return 0.24
|
||||
// }
|
||||
// return 0.24
|
||||
//}
|
||||
|
||||
func GetCompletionRatioMap() map[string]float64 {
|
||||
if CompletionRatio == nil {
|
||||
CompletionRatio = defaultCompletionRatio
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/net/proxy"
|
||||
"html/template"
|
||||
"log"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -206,3 +211,56 @@ func RandomSleep() {
|
||||
// Sleep for 0-3000 ms
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
func GetProxiedHttpClient(proxyUrl string) (*http.Client, error) {
|
||||
if "" == proxyUrl {
|
||||
return &http.Client{}, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(proxyUrl, "http") {
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(u),
|
||||
},
|
||||
}, nil
|
||||
} else if strings.HasPrefix(proxyUrl, "socks") {
|
||||
dialer, err := proxy.FromURL(u, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported proxy type")
|
||||
}
|
||||
|
||||
func ProxiedHttpGet(url, proxyUrl string) (*http.Response, error) {
|
||||
client, err := GetProxiedHttpClient(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Get(url)
|
||||
}
|
||||
|
||||
func ProxiedHttpHead(url, proxyUrl string) (*http.Response, error) {
|
||||
client, err := GetProxiedHttpClient(proxyUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.Head(url)
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package constant
|
||||
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var WorkerUrl = ""
|
||||
var WorkerValidKey = ""
|
||||
|
||||
func EnableWorker() bool {
|
||||
return WorkerUrl != ""
|
||||
}
|
||||
@@ -102,22 +102,17 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
}
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err
|
||||
}
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, meta)
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
}
|
||||
if usageA == nil {
|
||||
if usage == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
@@ -151,7 +146,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
Model: "", // this will be set later
|
||||
Stream: false,
|
||||
}
|
||||
if strings.HasPrefix(model, "o1-") {
|
||||
if "o1" == model || strings.HasPrefix(model, "o1-") {
|
||||
testRequest.MaxCompletionTokens = 1
|
||||
} else {
|
||||
testRequest.MaxTokens = 1
|
||||
|
||||
@@ -133,6 +133,8 @@ func GitHubOAuth(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(c.Query("aff"))
|
||||
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
@@ -143,7 +145,7 @@ func GitHubOAuth(c *gin.Context) {
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
|
||||
247
controller/linuxdo.go
Normal file
247
controller/linuxdo.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LinuxDoOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type LinuxDoUser struct {
|
||||
ID int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func getLinuxDoUserInfoByCode(code string) (*LinuxDoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(common.LinuxDoClientId + ":" + common.LinuxDoClientSecret))
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://connect.linux.do/oauth2/token", bytes.NewBufferString(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse LinuxDoOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://connect.linux.do/api/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 LINUX DO 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var linuxdoUser LinuxDoUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&linuxdoUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if linuxdoUser.ID == 0 {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
if linuxdoUser.TrustLevel < common.LinuxDoMinLevel {
|
||||
return nil, errors.New("用户 LINUX DO 信任等级不足!")
|
||||
}
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxDoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
err := user.FillUserByLinuxDoId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
affCode := c.Query("aff")
|
||||
user.InviterId, _ = model.GetUserIdByAffCode(affCode)
|
||||
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if linuxdoUser.Name != "" {
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
} else {
|
||||
user.DisplayName = linuxdoUser.Username
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(user.InviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDoOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 LINUX DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxDoUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LinuxDoId: strconv.Itoa(linuxdoUser.ID),
|
||||
LinuxDoLevel: linuxdoUser.TrustLevel,
|
||||
}
|
||||
if model.IsLinuxDoIdAlreadyTaken(user.LinuxDoId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 LINUX DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.LinuxDoId = strconv.Itoa(linuxdoUser.ID)
|
||||
user.LinuxDoLevel = linuxdoUser.TrustLevel
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func DeleteHistoryLogs(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
count, err := model.DeleteOldLog(targetTimestamp)
|
||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -233,7 +233,7 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
@@ -265,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDoOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDoClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
@@ -45,9 +47,9 @@ func GetStatus(c *gin.Context) {
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": constant.ServerAddress,
|
||||
"price": constant.Price,
|
||||
"min_topup": constant.MinTopUp,
|
||||
"server_address": common.ServerAddress,
|
||||
"stripe_unit_price": common.StripeUnitPrice,
|
||||
"min_topup": common.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -61,7 +63,7 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
|
||||
"payment_enabled": common.PaymentEnabled,
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
"chats": constant.Chats,
|
||||
},
|
||||
@@ -205,7 +207,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
|
||||
@@ -50,6 +50,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "LinuxDoOAuthEnabled":
|
||||
if option.Value == "true" && common.LinuxDoClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 LINUX DO OAuth,请先填入 LINUX DO Client Id 以及 LINUX DO Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -39,15 +38,6 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
return err
|
||||
}
|
||||
|
||||
func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
switch relayMode {
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
|
||||
@@ -144,67 +134,6 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许跨域
|
||||
},
|
||||
}
|
||||
|
||||
func WssRelay(c *gin.Context) {
|
||||
// 将 HTTP 连接升级为 WebSocket 连接
|
||||
|
||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
service.WssError(c, ws, openaiErr.Error)
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
service.WssError(c, ws, openaiErr.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
@@ -212,13 +141,6 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func addUsedChannel(c *gin.Context, channelId int) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
|
||||
99
controller/stripe.go
Normal file
99
controller/stripe.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v79"
|
||||
"github.com/stripe/stripe-go/v79/webhook"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func StripeWebhook(c *gin.Context) {
|
||||
payload, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("Stripe-Signature")
|
||||
endpointSecret := common.StripeWebhookSecret
|
||||
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case stripe.EventTypeCheckoutSessionCompleted:
|
||||
sessionCompleted(event)
|
||||
case stripe.EventTypeCheckoutSessionExpired:
|
||||
sessionExpired(event)
|
||||
default:
|
||||
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func sessionCompleted(event stripe.Event) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "complete" != status {
|
||||
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
err := model.Recharge(referenceId, customerId)
|
||||
if err != nil {
|
||||
log.Println(err.Error(), referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||
}
|
||||
|
||||
func sessionExpired(event stripe.Event) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "expired" != status {
|
||||
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if "" == referenceId {
|
||||
log.Println("未提供支付单号")
|
||||
return
|
||||
}
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("充值订单不存在", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Println("充值订单状态错误", referenceId)
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusExpired
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("充值订单已过期", referenceId)
|
||||
}
|
||||
@@ -1,22 +1,20 @@
|
||||
package controller
|
||||
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/Calcium-Ion/go-epay/epay"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/stripe/stripe-go/v79"
|
||||
"github.com/stripe/stripe-go/v79/checkout/session"
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EpayRequest struct {
|
||||
type PayRequest struct {
|
||||
Amount int `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
@@ -27,201 +25,114 @@ type AmountRequest struct {
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
|
||||
return nil
|
||||
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||
if !strings.HasPrefix(common.StripeApiSecret, "sk_") {
|
||||
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: constant.EpayId,
|
||||
Key: constant.EpayKey,
|
||||
}, constant.PayAddress)
|
||||
|
||||
stripe.Key = common.StripeApiSecret
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(common.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(common.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(common.StripePriceId),
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
if "" != email {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||
} else {
|
||||
params.Customer = stripe.String(customerId)
|
||||
}
|
||||
|
||||
result, err := session.New(params)
|
||||
if err != nil {
|
||||
return nil
|
||||
return "", err
|
||||
}
|
||||
return withUrl
|
||||
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func getPayMoney(amount float64, group string) float64 {
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
// 别问为什么用float64,问就是这么点钱没必要
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * constant.Price * topupGroupRatio
|
||||
return payMoney
|
||||
func GetPayAmount(count float64) float64 {
|
||||
return count * common.StripeUnitPrice
|
||||
}
|
||||
|
||||
func getMinTopup() int {
|
||||
minTopup := constant.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
func GetChargedAmount(count float64, user model.User) float64 {
|
||||
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topUpGroupRatio == 0 {
|
||||
topUpGroupRatio = 1
|
||||
}
|
||||
return minTopup
|
||||
|
||||
return count * topUpGroupRatio
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
func RequestPayLink(c *gin.Context) {
|
||||
var req PayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
if !common.PaymentEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||
return
|
||||
}
|
||||
if req.PaymentMethod != "stripe" {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 10000 {
|
||||
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(float64(req.Amount), group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
user, _ := model.GetUserById(id, false)
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
|
||||
var payType epay.PurchaseType
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = epay.Alipay
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = epay.WechatPay
|
||||
}
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
Device: epay.PC,
|
||||
NotifyUrl: notifyUrl,
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), common.RandomString(4))
|
||||
referenceId := "ref_" + common.Sha1(reference)
|
||||
|
||||
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, int64(req.Amount))
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
amount := req.Amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / int(common.QuotaPerUnit)
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: "pending",
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
// tradeNo lock
|
||||
var orderLocks sync.Map
|
||||
var createLock sync.Mutex
|
||||
|
||||
// LockOrder 尝试对给定订单号加锁
|
||||
func LockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
createLock.Lock()
|
||||
defer createLock.Unlock()
|
||||
lock, ok = orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
lock = new(sync.Mutex)
|
||||
orderLocks.Store(tradeNo, lock)
|
||||
}
|
||||
}
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
// UnlockOrder 释放给定订单号的锁
|
||||
func UnlockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
||||
r[t] = c.Request.URL.Query().Get(t)
|
||||
return r
|
||||
}, map[string]string{})
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
log.Println("易支付回调失败 未找到配置信息")
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
return
|
||||
}
|
||||
}
|
||||
verifyInfo, err := client.Verify(params)
|
||||
if err == nil && verifyInfo.VerifyStatus {
|
||||
_, err := c.Writer.Write([]byte("success"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
} else {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
log.Println("易支付回调签名验证失败")
|
||||
return
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp == nil {
|
||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
||||
return
|
||||
}
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新订单失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*int(common.QuotaPerUnit))
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(topUp.Amount*int(common.QuotaPerUnit)), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payLink": payLink,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func RequestAmount(c *gin.Context) {
|
||||
@@ -231,21 +142,23 @@ func RequestAmount(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
if !common.PaymentEnabled {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "管理员未开启在线支付"})
|
||||
return
|
||||
}
|
||||
if req.Amount < common.MinTopUp {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(float64(req.Amount), group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
user, _ := model.GetUserById(id, false)
|
||||
payMoney := GetPayAmount(float64(req.Amount))
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"payAmount": strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
"chargedAmount": strconv.FormatFloat(chargedMoney, 'f', 2, 64),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -69,6 +69,7 @@ func setupLogin(user *model.User, c *gin.Context) {
|
||||
session.Set("role", user.Role)
|
||||
session.Set("status", user.Status)
|
||||
session.Set("group", user.Group)
|
||||
session.Set("linuxdo_enable", user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -574,7 +575,7 @@ func UpdateSelf(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteUser(c *gin.Context) {
|
||||
func HardDeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -583,7 +584,7 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
originUser, err := model.GetUserByIdUnscoped(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -607,9 +608,23 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteSelf(c *gin.Context) {
|
||||
if !common.UserSelfDeletionEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "当前设置不允许用户自我删除账号",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
|
||||
|
||||
@@ -2,18 +2,17 @@ version: '3.4'
|
||||
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
# build: .
|
||||
image: pengzhile/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./data/new-api:/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- SQL_DSN=newapi:123456@tcp(db:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
@@ -23,13 +22,22 @@ services:
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
healthcheck:
|
||||
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
- db
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
restart: always
|
||||
|
||||
db:
|
||||
image: mysql:8.2.0
|
||||
container_name: mysql
|
||||
restart: always
|
||||
volumes:
|
||||
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
|
||||
environment:
|
||||
TZ: Asia/Shanghai # 设置时区
|
||||
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
|
||||
MYSQL_USER: newapi # 创建专用用户
|
||||
MYSQL_PASSWORD: '123456' # 设置专用用户密码
|
||||
MYSQL_DATABASE: new-api # 自动创建数据库
|
||||
@@ -1,97 +0,0 @@
|
||||
package dto
|
||||
|
||||
const (
|
||||
RealtimeEventTypeError = "error"
|
||||
RealtimeEventTypeSessionUpdate = "session.update"
|
||||
RealtimeEventTypeConversationCreate = "conversation.item.create"
|
||||
RealtimeEventTypeResponseCreate = "response.create"
|
||||
RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
|
||||
)
|
||||
|
||||
const (
|
||||
RealtimeEventTypeResponseDone = "response.done"
|
||||
RealtimeEventTypeSessionUpdated = "session.updated"
|
||||
RealtimeEventTypeSessionCreated = "session.created"
|
||||
RealtimeEventResponseAudioDelta = "response.audio.delta"
|
||||
RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
|
||||
RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
|
||||
RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
|
||||
RealtimeEventConversationItemCreated = "conversation.item.created"
|
||||
)
|
||||
|
||||
type RealtimeEvent struct {
|
||||
EventId string `json:"event_id"`
|
||||
Type string `json:"type"`
|
||||
//PreviousItemId string `json:"previous_item_id"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type RealtimeResponse struct {
|
||||
Usage *RealtimeUsage `json:"usage"`
|
||||
}
|
||||
|
||||
type RealtimeUsage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails InputTokenDetails `json:"input_token_details"`
|
||||
OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
}
|
||||
|
||||
type OutputTokenDetails struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
}
|
||||
|
||||
type RealtimeSession struct {
|
||||
Modalities []string `json:"modalities"`
|
||||
Instructions string `json:"instructions"`
|
||||
Voice string `json:"voice"`
|
||||
InputAudioFormat string `json:"input_audio_format"`
|
||||
OutputAudioFormat string `json:"output_audio_format"`
|
||||
InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
|
||||
TurnDetection interface{} `json:"turn_detection"`
|
||||
Tools []RealTimeTool `json:"tools"`
|
||||
ToolChoice string `json:"tool_choice"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
//MaxResponseOutputTokens int `json:"max_response_output_tokens"`
|
||||
}
|
||||
|
||||
type InputAudioTranscription struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type RealTimeTool struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters any `json:"parameters"`
|
||||
}
|
||||
|
||||
type RealtimeItem struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Content []RealtimeContent `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
CallId string `json:"call_id,omitempty"`
|
||||
}
|
||||
type RealtimeContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
|
||||
Transcript string `json:"transcript,omitempty"`
|
||||
}
|
||||
@@ -2,16 +2,16 @@ package dto
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
BestOf int `json:"best_of,omitempty"`
|
||||
Echo bool `json:"echo,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
Suffix string `json:"suffix,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
@@ -26,16 +26,16 @@ type GeneralOpenAIRequest struct {
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
LogitBias any `json:"logit_bias,omitempty"`
|
||||
LogProbs any `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAITools struct {
|
||||
@@ -85,10 +85,9 @@ type Message struct {
|
||||
}
|
||||
|
||||
type MediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
InputAudio any `json:"input_audio,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type MessageImageUrl struct {
|
||||
@@ -96,15 +95,9 @@ type MessageImageUrl struct {
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
type MessageInputAudio struct {
|
||||
Data string `json:"data"` //base64
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
ContentTypeInputAudio = "input_audio"
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
)
|
||||
|
||||
func (m Message) StringContent() string {
|
||||
@@ -177,19 +170,11 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
},
|
||||
})
|
||||
}
|
||||
case ContentTypeInputAudio:
|
||||
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: MessageInputAudio{
|
||||
Data: subObj["data"].(string),
|
||||
Format: subObj["format"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -42,9 +42,9 @@ type OpenAITextResponse struct {
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding any `json:"embedding"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponse struct {
|
||||
5
go.mod
5
go.mod
@@ -6,7 +6,6 @@ go 1.21
|
||||
toolchain go1.22.4
|
||||
|
||||
require (
|
||||
github.com/Calcium-Ion/go-epay v0.0.2
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||
github.com/aws/aws-sdk-go-v2 v1.26.1
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
@@ -27,8 +26,10 @@ require (
|
||||
github.com/pkoukk/tiktoken-go v0.1.7
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/stripe/stripe-go/v79 v79.12.0
|
||||
golang.org/x/crypto v0.26.0
|
||||
golang.org/x/image v0.15.0
|
||||
golang.org/x/net v0.28.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@@ -68,7 +69,6 @@ require (
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
@@ -80,7 +80,6 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/net v0.28.0 // indirect
|
||||
golang.org/x/sync v0.8.0 // indirect
|
||||
golang.org/x/sys v0.24.0 // indirect
|
||||
golang.org/x/text v0.17.0 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -1,5 +1,3 @@
|
||||
github.com/Calcium-Ion/go-epay v0.0.2 h1:3knFBuaBFpHzsGeGQU/QxUqZSHh5s0+jGo0P62pJzWc=
|
||||
github.com/Calcium-Ion/go-epay v0.0.2/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
|
||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||
@@ -136,8 +134,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -180,6 +176,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stripe/stripe-go/v79 v79.12.0 h1:HQs/kxNEB3gYA7FnkSFkp0kSOeez0fsmCWev6SxftYs=
|
||||
github.com/stripe/stripe-go/v79 v79.12.0/go.mod h1:cuH6X0zC8peY6f1AubHwgJ/fJSn2dh5pfiCr6CjyKVU=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
@@ -205,6 +203,7 @@ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSO
|
||||
golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8=
|
||||
golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
|
||||
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -213,6 +212,7 @@ golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -27,6 +27,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role := session.Get("role")
|
||||
id := session.Get("id")
|
||||
status := session.Get("status")
|
||||
linuxDoEnable := session.Get("linuxdo_enable")
|
||||
useAccessToken := false
|
||||
if username == nil {
|
||||
// Check access token
|
||||
@@ -54,6 +55,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
role = user.Role
|
||||
id = user.Id
|
||||
status = user.Status
|
||||
linuxDoEnable = user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel
|
||||
useAccessToken = true
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -102,6 +104,14 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if nil != linuxDoEnable && !linuxDoEnable.(bool) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户 LINUX DO 信任等级不足",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if role.(int) < minRole {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -155,27 +165,8 @@ func RootAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func WssAuth(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
// 先检测是否为ws
|
||||
if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
|
||||
// Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
|
||||
// read sk from Sec-WebSocket-Protocol
|
||||
key := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||
parts := strings.Split(key, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "openai-insecure-api-key") {
|
||||
key = strings.TrimPrefix(part, "openai-insecure-api-key.")
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := make([]string, 0)
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
@@ -210,6 +201,15 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
linuxDoEnabled, err := model.CacheIsLinuxDoEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !linuxDoEnabled {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "用户 LINUX DO 信任等级不足")
|
||||
return
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
|
||||
@@ -170,10 +170,6 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
modelRequest.Model = c.Query("model")
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
|
||||
@@ -205,6 +205,30 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
func CacheIsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsLinuxDoEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("linuxdo_enabled:%d", userId))
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
|
||||
linuxDoEnabled, err := IsLinuxDoEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if linuxDoEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("linuxdo_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set linuxdo enabled error: " + err.Error())
|
||||
}
|
||||
return linuxDoEnabled, err
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
@@ -269,9 +293,8 @@ func SyncChannelCache(frequency int) {
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(model, "gpt-4o-gizmo") {
|
||||
model = "gpt-4o-gizmo-*"
|
||||
} else if strings.HasPrefix(model, "g-") {
|
||||
model = "g-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
|
||||
24
model/log.go
24
model/log.go
@@ -247,7 +247,25 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
return token
|
||||
}
|
||||
|
||||
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||
return result.RowsAffected, result.Error
|
||||
func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
|
||||
var total int64 = 0
|
||||
|
||||
for {
|
||||
if nil != ctx.Err() {
|
||||
return total, ctx.Err()
|
||||
}
|
||||
|
||||
result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
|
||||
if nil != result.Error {
|
||||
return total, result.Error
|
||||
}
|
||||
|
||||
total += result.RowsAffected
|
||||
|
||||
if result.RowsAffected < int64(limit) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
@@ -31,10 +31,12 @@ func InitOptionMap() {
|
||||
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
|
||||
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
|
||||
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
|
||||
common.OptionMap["LinuxDoOAuthEnabled"] = strconv.FormatBool(common.LinuxDoOAuthEnabled)
|
||||
common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
|
||||
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||
common.OptionMap["UserSelfDeletionEnabled"] = strconv.FormatBool(common.UserSelfDeletionEnabled)
|
||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||
@@ -60,18 +62,20 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["WorkerUrl"] = constant.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
|
||||
common.OptionMap["OutProxyUrl"] = ""
|
||||
common.OptionMap["StripeApiSecret"] = common.StripeApiSecret
|
||||
common.OptionMap["StripeWebhookSecret"] = common.StripeWebhookSecret
|
||||
common.OptionMap["StripePriceId"] = common.StripePriceId
|
||||
common.OptionMap["PaymentEnabled"] = strconv.FormatBool(common.PaymentEnabled)
|
||||
common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(common.StripeUnitPrice, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = constant.Chats2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoClientId"] = ""
|
||||
common.OptionMap["LinuxDoClientSecret"] = ""
|
||||
common.OptionMap["LinuxDoMinLevel"] = strconv.Itoa(common.LinuxDoMinLevel)
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
common.OptionMap["TelegramBotName"] = ""
|
||||
common.OptionMap["WeChatServerAddress"] = ""
|
||||
@@ -175,6 +179,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.EmailVerificationEnabled = boolValue
|
||||
case "GitHubOAuthEnabled":
|
||||
common.GitHubOAuthEnabled = boolValue
|
||||
case "LinuxDoOAuthEnabled":
|
||||
common.LinuxDoOAuthEnabled = boolValue
|
||||
case "WeChatAuthEnabled":
|
||||
common.WeChatAuthEnabled = boolValue
|
||||
case "TelegramOAuthEnabled":
|
||||
@@ -183,6 +189,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.TurnstileCheckEnabled = boolValue
|
||||
case "RegisterEnabled":
|
||||
common.RegisterEnabled = boolValue
|
||||
case "UserSelfDeletionEnabled":
|
||||
common.UserSelfDeletionEnabled = boolValue
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "EmailAliasRestrictionEnabled":
|
||||
@@ -242,31 +250,35 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPToken":
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
constant.ServerAddress = value
|
||||
case "WorkerUrl":
|
||||
constant.WorkerUrl = value
|
||||
case "WorkerValidKey":
|
||||
constant.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
constant.PayAddress = value
|
||||
common.ServerAddress = value
|
||||
case "OutProxyUrl":
|
||||
common.OutProxyUrl = value
|
||||
case "Chats":
|
||||
err = constant.UpdateChatsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
constant.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
constant.EpayId = value
|
||||
case "EpayKey":
|
||||
constant.EpayKey = value
|
||||
case "Price":
|
||||
constant.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "StripeApiSecret":
|
||||
common.StripeApiSecret = value
|
||||
case "StripeWebhookSecret":
|
||||
common.StripeWebhookSecret = value
|
||||
case "StripePriceId":
|
||||
common.StripePriceId = value
|
||||
case "PaymentEnabled":
|
||||
common.PaymentEnabled, _ = strconv.ParseBool(value)
|
||||
case "StripeUnitPrice":
|
||||
common.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
constant.MinTopUp, _ = strconv.Atoi(value)
|
||||
common.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
common.GitHubClientId = value
|
||||
case "GitHubClientSecret":
|
||||
common.GitHubClientSecret = value
|
||||
case "LinuxDoClientId":
|
||||
common.LinuxDoClientId = value
|
||||
case "LinuxDoClientSecret":
|
||||
common.LinuxDoClientSecret = value
|
||||
case "LinuxDoMinLevel":
|
||||
common.LinuxDoMinLevel, _ = strconv.Atoi(value)
|
||||
case "Footer":
|
||||
common.Footer = value
|
||||
case "SystemName":
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -325,7 +324,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
|
||||
prompt = "您的额度已用尽"
|
||||
}
|
||||
if email != "" {
|
||||
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
|
||||
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
|
||||
err = common.SendEmail(prompt, email,
|
||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
||||
if err != nil {
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
type TopUp struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
Status string `json:"status"`
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Amount int `json:"amount"`
|
||||
Money float64 `json:"money"`
|
||||
TradeNo string `json:"trade_no" gorm:"unique"`
|
||||
CreateTime int64 `json:"create_time"`
|
||||
CompleteTime int64 `json:"complete_time"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func (topUp *TopUp) Insert() error {
|
||||
@@ -41,3 +49,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
|
||||
}
|
||||
return topUp
|
||||
}
|
||||
|
||||
func Recharge(referenceId string, customerId string) (err error) {
|
||||
if referenceId == "" {
|
||||
return errors.New("未提供支付单号")
|
||||
}
|
||||
|
||||
var quota float64
|
||||
topUp := &TopUp{}
|
||||
|
||||
refCol := "`trade_no`"
|
||||
if common.UsingPostgreSQL {
|
||||
refCol = `"trade_no"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
|
||||
if err != nil {
|
||||
return errors.New("充值订单不存在")
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
return errors.New("充值订单状态错误")
|
||||
}
|
||||
|
||||
topUp.CompleteTime = common.GetTimestamp()
|
||||
topUp.Status = common.TopUpStatusSuccess
|
||||
err = tx.Save(topUp).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
quota = topUp.Money * common.QuotaPerUnit
|
||||
err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return errors.New("充值失败," + err.Error())
|
||||
}
|
||||
|
||||
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.LogQuotaF(quota), topUp.Amount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
LinuxDoId string `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
|
||||
LinuxDoLevel int `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
@@ -35,6 +37,7 @@ type User struct {
|
||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
StripeCustomer string `json:"stripe_customer" gorm:"column:stripe_customer;index"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
@@ -75,7 +78,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
|
||||
func GetMaxUserId() int {
|
||||
var user User
|
||||
DB.Last(&user)
|
||||
DB.Unscoped().Last(&user)
|
||||
return user.Id
|
||||
}
|
||||
|
||||
@@ -88,13 +91,18 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
var users []*User
|
||||
var err error
|
||||
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果转换成功,按照ID和可选的组别搜索用户
|
||||
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
|
||||
query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
|
||||
if group != "" {
|
||||
query = query.Where("`group` = ?", group) // 使用反引号包围group
|
||||
query = query.Where(groupCol+" = ?", group) // 使用反引号包围group
|
||||
}
|
||||
err = query.Find(&users).Error
|
||||
if err != nil || len(users) > 0 {
|
||||
@@ -105,9 +113,9 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
err = nil
|
||||
|
||||
query := DB.Unscoped().Omit("password")
|
||||
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
|
||||
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
@@ -130,6 +138,20 @@ func GetUserById(id int, selectAll bool) (*User, error) {
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("id 为空!")
|
||||
}
|
||||
user := User{Id: id}
|
||||
var err error = nil
|
||||
if selectAll {
|
||||
err = DB.Unscoped().First(&user, "id = ?", id).Error
|
||||
} else {
|
||||
err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
|
||||
}
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func GetUserIdByAffCode(affCode string) (int, error) {
|
||||
if affCode == "" {
|
||||
return 0, errors.New("affCode 为空!")
|
||||
@@ -343,6 +365,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByLinuxDoId() error {
|
||||
if user.LinuxDoId == "" {
|
||||
return errors.New("LINUX DO id 为空!")
|
||||
}
|
||||
DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByWeChatId() error {
|
||||
if user.WeChatId == "" {
|
||||
return errors.New("WeChat id 为空!")
|
||||
@@ -374,6 +404,14 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
|
||||
return DB.Unscoped().Where("linuxdo_id = ?", linuxdoId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsUsernameAlreadyTaken(username string) bool {
|
||||
return DB.Unscoped().Where("username = ?", username).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsTelegramIdAlreadyTaken(telegramId string) bool {
|
||||
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
@@ -415,6 +453,18 @@ func IsUserEnabled(userId int) (bool, error) {
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
func IsLinuxDoEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false, errors.New("user id is empty")
|
||||
}
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
|
||||
}
|
||||
|
||||
func ValidateAccessToken(token string) (user *User) {
|
||||
if token == "" {
|
||||
return nil
|
||||
|
||||
@@ -12,13 +12,13 @@ type Adaptor interface {
|
||||
// Init IsStream bool
|
||||
Init(info *relaycommon.RelayInfo)
|
||||
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
|
||||
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
|
||||
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
|
||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
}
|
||||
|
||||
@@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
if info.IsStream {
|
||||
req.Set("X-DashScope-SSE", "enable")
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
|
||||
@@ -36,8 +36,8 @@ type AliEmbeddingRequest struct {
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
Embedding any `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
|
||||
@@ -105,7 +105,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
for _, data := range response.Output.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
_, b64, err := common.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
common.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/relay/common"
|
||||
@@ -12,16 +11,14 @@ import (
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
|
||||
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
// multipart/form-data
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
// websocket
|
||||
} else {
|
||||
req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if info.IsStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -35,7 +32,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request failed: %w", err)
|
||||
}
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
err = a.SetupRequestHeader(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -58,7 +55,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
// set form data
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
|
||||
err = a.SetupRequestHeader(c, &req.Header, info)
|
||||
err = a.SetupRequestHeader(c, req, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
@@ -69,27 +66,6 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
|
||||
fullRequestURL, err := a.GetRequestURL(info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get request url failed: %w", err)
|
||||
}
|
||||
targetHeader := http.Header{}
|
||||
err = a.SetupRequestHeader(c, &targetHeader, info)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup request header failed: %w", err)
|
||||
}
|
||||
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
|
||||
}
|
||||
// send request body
|
||||
//all, err := io.ReadAll(requestBody)
|
||||
//err = service.WssString(c, targetConn, string(all))
|
||||
return targetConn, nil
|
||||
}
|
||||
|
||||
func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
} else {
|
||||
|
||||
@@ -50,9 +50,9 @@ type BaiduEmbeddingRequest struct {
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
Object string `json:"object"`
|
||||
Embedding any `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
|
||||
@@ -47,14 +47,21 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("x-api-key", info.ApiKey)
|
||||
req.Header.Set("x-api-key", info.ApiKey)
|
||||
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
anthropicVersion = "2023-06-01"
|
||||
}
|
||||
req.Set("anthropic-version", anthropicVersion)
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
|
||||
anthropicBeta := c.Request.Header.Get("anthropic-beta")
|
||||
if "" != anthropicBeta {
|
||||
req.Header.Set("anthropic-beta", anthropicBeta)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -73,11 +80,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -8,7 +8,6 @@ var ModelList = []string{
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
}
|
||||
|
||||
@@ -31,9 +31,13 @@ type ClaudeMessage struct {
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema"`
|
||||
Name string `json:"name"`
|
||||
Description any `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema,omitempty"`
|
||||
Type any `json:"type,omitempty"`
|
||||
DisplayHeightPx int `json:"display_height_px,omitempty"`
|
||||
DisplayWidthPx int `json:"display_width_px,omitempty"`
|
||||
DisplayNumber int `json:"display_number,omitempty"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
|
||||
@@ -63,23 +63,44 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
claudeTools := make([]Tool, 0, len(textRequest.Tools))
|
||||
|
||||
for _, tool := range textRequest.Tools {
|
||||
if params, ok := tool.Function.Parameters.(map[string]any); ok {
|
||||
claudeTool := Tool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
}
|
||||
claudeTool.InputSchema = make(map[string]interface{})
|
||||
claudeTool.InputSchema["type"] = params["type"].(string)
|
||||
claudeTool.InputSchema["properties"] = params["properties"]
|
||||
claudeTool.InputSchema["required"] = params["required"]
|
||||
for s, a := range params {
|
||||
if s == "type" || s == "properties" || s == "required" {
|
||||
continue
|
||||
}
|
||||
claudeTool.InputSchema[s] = a
|
||||
}
|
||||
claudeTools = append(claudeTools, claudeTool)
|
||||
claudeTool := Tool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
Type: tool.Type,
|
||||
}
|
||||
|
||||
if "function" != tool.Type {
|
||||
claudeTool.Description = nil
|
||||
}
|
||||
|
||||
if params, ok := tool.Function.Parameters.(map[string]any); ok {
|
||||
if "function" == tool.Type {
|
||||
claudeTool.Type = nil
|
||||
|
||||
claudeTool.InputSchema = make(map[string]interface{})
|
||||
claudeTool.InputSchema["type"] = params["type"].(string)
|
||||
claudeTool.InputSchema["properties"] = params["properties"]
|
||||
claudeTool.InputSchema["required"] = params["required"]
|
||||
for s, a := range params {
|
||||
if s == "type" || s == "properties" || s == "required" {
|
||||
continue
|
||||
}
|
||||
claudeTool.InputSchema[s] = a
|
||||
}
|
||||
} else {
|
||||
if val, ok := params["display_height_px"]; ok {
|
||||
claudeTool.DisplayHeightPx = int(val.(float64))
|
||||
}
|
||||
if val, ok := params["display_width_px"]; ok {
|
||||
claudeTool.DisplayWidthPx = int(val.(float64))
|
||||
}
|
||||
if val, ok := params["display_number"]; ok {
|
||||
claudeTool.DisplayNumber = int(val.(float64))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
claudeTools = append(claudeTools, claudeTool)
|
||||
}
|
||||
|
||||
claudeRequest := ClaudeRequest{
|
||||
@@ -91,6 +112,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
TopK: textRequest.TopK,
|
||||
Stream: textRequest.Stream,
|
||||
Tools: claudeTools,
|
||||
ToolChoice: textRequest.ToolChoice,
|
||||
}
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = 4096
|
||||
@@ -225,11 +247,11 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
|
||||
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = data
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
|
||||
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -509,7 +531,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
@@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fallthrough
|
||||
|
||||
@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return nil, usage
|
||||
|
||||
@@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return requestOpenAI2Cohere(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return requestConvertRerank2Cohere(request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = difyStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
return nil, usage
|
||||
|
||||
@@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("x-goog-api-key", info.ApiKey)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -86,7 +86,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
@@ -94,7 +94,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
_, format, base64String, err := common.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = jinaRerankHandler(c, resp)
|
||||
} else if info.RelayMode == constant.RelayModeEmbeddings {
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
mistralReq := requestOpenAI2Mistral(*request)
|
||||
//common.LogJson(c, "body", mistralReq)
|
||||
return mistralReq, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package mistral
|
||||
|
||||
var ModelList = []string{
|
||||
"open-mistral-7b",
|
||||
"open-mixtral-8x7b",
|
||||
"mistral-small-latest",
|
||||
"mistral-medium-latest",
|
||||
"mistral-large-latest",
|
||||
"mistral-embed",
|
||||
}
|
||||
|
||||
var ChannelName = "mistral"
|
||||
@@ -1,40 +0,0 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
func requestOpenAI2Mistral(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
|
||||
messages := make([]dto.Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if !message.IsStringContent() {
|
||||
mediaMessages := message.ParseContent()
|
||||
for j, mediaMessage := range mediaMessages {
|
||||
if mediaMessage.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
|
||||
mediaMessage.ImageUrl = imageUrl.Url
|
||||
mediaMessages[j] = mediaMessage
|
||||
}
|
||||
}
|
||||
messageRaw, _ := json.Marshal(mediaMessages)
|
||||
message.Content = messageRaw
|
||||
}
|
||||
messages = append(messages, dto.Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
})
|
||||
}
|
||||
return &dto.GeneralOpenAIRequest{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxTokens: request.MaxTokens,
|
||||
Tools: request.Tools,
|
||||
ToolChoice: request.ToolChoice,
|
||||
}
|
||||
}
|
||||
@@ -31,13 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
return info.BaseUrl + "/api/embed", nil
|
||||
return info.BaseUrl + "/api/embeddings", nil
|
||||
default:
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
@@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -35,7 +35,7 @@ type OllamaEmbeddingRequest struct {
|
||||
}
|
||||
|
||||
type OllamaEmbeddingResponse struct {
|
||||
Embedding any `json:"embedding,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Embedding []float64 `json:"embedding,omitempty"`
|
||||
}
|
||||
|
||||
@@ -31,13 +31,6 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
// trim https
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
baseUrl = strings.TrimPrefix(baseUrl, "http://")
|
||||
baseUrl = "wss://" + baseUrl
|
||||
info.BaseUrl = baseUrl
|
||||
}
|
||||
switch info.ChannelType {
|
||||
case common.ChannelTypeAzure:
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
@@ -47,10 +40,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
model_ := info.UpstreamModelName
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
case common.ChannelTypeMiniMax:
|
||||
return minimax.GetRequestURL(info)
|
||||
@@ -63,34 +54,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, header)
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
header.Set("api-key", info.ApiKey)
|
||||
req.Header.Set("api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
||||
header.Set("OpenAI-Organization", info.Organization)
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||
if swp != "" {
|
||||
items := []string{
|
||||
"realtime",
|
||||
"openai-insecure-api-key." + info.ApiKey,
|
||||
"openai-beta.realtime-v1",
|
||||
}
|
||||
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
|
||||
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
|
||||
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
|
||||
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
|
||||
} else {
|
||||
header.Set("openai-beta", "realtime=v1")
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
} else {
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Header.Set("OpenAI-Organization", info.Organization)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
||||
// req.Header.Set("X-Title", "One API")
|
||||
@@ -105,7 +78,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
if info.ChannelType != common.ChannelTypeOpenAI {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o1-") {
|
||||
if "o1" == request.Model || strings.HasPrefix(request.Model, "o1-") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
@@ -158,20 +131,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
||||
return channel.DoFormRequest(a, c, info, requestBody)
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
return channel.DoWssRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRealtime:
|
||||
err, usage = OpenaiRealtimeHandler(c, info)
|
||||
case constant.RelayModeAudioSpeech:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeAudioTranslation:
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
package openai
|
||||
|
||||
var ModelList = []string{
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0613",
|
||||
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-vision-preview",
|
||||
"chatgpt-4o-latest",
|
||||
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06",
|
||||
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-08-06",
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001",
|
||||
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
|
||||
"text-moderation-latest", "text-moderation-stable",
|
||||
"text-davinci-edit-001",
|
||||
"davinci-002", "babbage-002",
|
||||
"dall-e-3",
|
||||
"dall-e-2", "dall-e-3",
|
||||
"whisper-1",
|
||||
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -232,7 +231,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
|
||||
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
@@ -325,7 +324,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
|
||||
usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return nil, usage
|
||||
}
|
||||
@@ -374,210 +373,3 @@ func getTextFromJSON(body []byte) (string, error) {
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
||||
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
|
||||
info.IsStream = true
|
||||
clientConn := info.ClientWs
|
||||
targetConn := info.TargetWs
|
||||
|
||||
clientClosed := make(chan struct{})
|
||||
targetClosed := make(chan struct{})
|
||||
sendChan := make(chan []byte, 100)
|
||||
receiveChan := make(chan []byte, 100)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
usage := &dto.RealtimeUsage{}
|
||||
localUsage := &dto.RealtimeUsage{}
|
||||
sumUsage := &dto.RealtimeUsage{}
|
||||
|
||||
gopool.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from client: %v", err)
|
||||
}
|
||||
close(clientClosed)
|
||||
return
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = json.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
|
||||
if realtimeEvent.Session != nil {
|
||||
if realtimeEvent.Session.Tools != nil {
|
||||
info.RealtimeTools = realtimeEvent.Session.Tools
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
|
||||
err = service.WssString(c, targetConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case sendChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
gopool.Go(func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.Done():
|
||||
return
|
||||
default:
|
||||
_, message, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
errChan <- fmt.Errorf("error reading from target: %v", err)
|
||||
}
|
||||
close(targetClosed)
|
||||
return
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = json.Unmarshal(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
|
||||
realtimeUsage := realtimeEvent.Response.Usage
|
||||
if realtimeUsage != nil {
|
||||
usage.TotalTokens += realtimeUsage.TotalTokens
|
||||
usage.InputTokens += realtimeUsage.InputTokens
|
||||
usage.OutputTokens += realtimeUsage.OutputTokens
|
||||
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
|
||||
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
|
||||
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
|
||||
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
|
||||
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
|
||||
err := preConsumeUsage(c, info, usage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
usage = &dto.RealtimeUsage{}
|
||||
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
info.IsFirstRequest = false
|
||||
localUsage.InputTokens += textToken + audioToken
|
||||
localUsage.InputTokenDetails.TextTokens += textToken
|
||||
localUsage.InputTokenDetails.AudioTokens += audioToken
|
||||
err = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error consume usage: %v", err)
|
||||
return
|
||||
}
|
||||
// 本次计费完成,清除
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
if realtimeSession != nil {
|
||||
// update audio format
|
||||
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
|
||||
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
|
||||
}
|
||||
} else {
|
||||
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error counting text token: %v", err)
|
||||
return
|
||||
}
|
||||
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
|
||||
localUsage.TotalTokens += textToken + audioToken
|
||||
localUsage.OutputTokens += textToken + audioToken
|
||||
localUsage.OutputTokenDetails.TextTokens += textToken
|
||||
localUsage.OutputTokenDetails.AudioTokens += audioToken
|
||||
}
|
||||
|
||||
err = service.WssString(c, clientConn, string(message))
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case receiveChan <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
select {
|
||||
case <-clientClosed:
|
||||
case <-targetClosed:
|
||||
case err := <-errChan:
|
||||
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
|
||||
common.LogError(c, "realtime error: "+err.Error())
|
||||
case <-c.Done():
|
||||
}
|
||||
|
||||
if usage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, usage, sumUsage)
|
||||
}
|
||||
|
||||
if localUsage.TotalTokens != 0 {
|
||||
_ = preConsumeUsage(c, info, localUsage, sumUsage)
|
||||
}
|
||||
|
||||
// check usage total tokens, if 0, use local usage
|
||||
|
||||
return nil, sumUsage
|
||||
}
|
||||
|
||||
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
|
||||
totalUsage.TotalTokens += usage.TotalTokens
|
||||
totalUsage.InputTokens += usage.InputTokens
|
||||
totalUsage.OutputTokens += usage.OutputTokens
|
||||
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
|
||||
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
|
||||
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
|
||||
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
|
||||
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
|
||||
// clear usage
|
||||
err := service.PreWssConsumeQuota(ctx, info, usage)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("x-goog-api-key", info.ApiKey)
|
||||
req.Header.Set("x-goog-api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
|
||||
@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
|
||||
@@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", errors.New("invalid relay mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = siliconflowRerankHandler(c, resp)
|
||||
|
||||
@@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", a.Sign)
|
||||
req.Set("X-TC-Action", a.Action)
|
||||
req.Set("X-TC-Version", a.Version)
|
||||
req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
|
||||
req.Header.Set("Authorization", a.Sign)
|
||||
req.Header.Set("X-TC-Action", a.Action)
|
||||
req.Header.Set("X-TC-Version", a.Version)
|
||||
req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
|
||||
@@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
switch a.RequestMode {
|
||||
case RequestModeClaude:
|
||||
|
||||
@@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
return nil
|
||||
}
|
||||
@@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
// xunfei's request is not http request, so we don't need to do anything here
|
||||
dummyResp := &http.Response{}
|
||||
dummyResp.StatusCode = http.StatusOK
|
||||
return dummyResp, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
splits := strings.Split(info.ApiKey, "|")
|
||||
if len(splits) != 3 {
|
||||
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
|
||||
@@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
token := getZhipuToken(info.ApiKey)
|
||||
req.Set("Authorization", token)
|
||||
req.Header.Set("Authorization", token)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = zhipuStreamHandler(c, resp)
|
||||
} else {
|
||||
|
||||
@@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
token := getZhipuToken(info.ApiKey)
|
||||
req.Set("Authorization", token)
|
||||
req.Header.Set("Authorization", token)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
} else {
|
||||
|
||||
@@ -2,9 +2,7 @@ package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,7 +21,6 @@ type RelayInfo struct {
|
||||
ApiType int
|
||||
IsStream bool
|
||||
IsPlayground bool
|
||||
UsePrice bool
|
||||
RelayMode int
|
||||
UpstreamModelName string
|
||||
OriginModelName string
|
||||
@@ -35,21 +32,6 @@ type RelayInfo struct {
|
||||
BaseUrl string
|
||||
SupportStreamOptions bool
|
||||
ShouldIncludeUsage bool
|
||||
ClientWs *websocket.Conn
|
||||
TargetWs *websocket.Conn
|
||||
InputAudioFormat string
|
||||
OutputAudioFormat string
|
||||
RealtimeTools []dto.RealTimeTool
|
||||
IsFirstRequest bool
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.ClientWs = ws
|
||||
info.InputAudioFormat = "pcm16"
|
||||
info.OutputAudioFormat = "pcm16"
|
||||
info.IsFirstRequest = true
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
|
||||
@@ -25,7 +25,6 @@ const (
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -73,8 +72,6 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeSiliconFlow
|
||||
case common.ChannelTypeVertexAi:
|
||||
apiType = APITypeVertexAi
|
||||
case common.ChannelTypeMistral:
|
||||
apiType = APITypeMistral
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
||||
@@ -38,8 +38,6 @@ const (
|
||||
RelayModeSunoSubmit
|
||||
|
||||
RelayModeRerank
|
||||
|
||||
RelayModeRealtime
|
||||
)
|
||||
|
||||
func Path2RelayMode(path string) int {
|
||||
@@ -66,8 +64,6 @@ func Path2RelayMode(path string) int {
|
||||
relayMode = RelayModeAudioTranslation
|
||||
} else if strings.HasPrefix(path, "/v1/rerank") {
|
||||
relayMode = RelayModeRerank
|
||||
} else if strings.HasPrefix(path, "/v1/realtime") {
|
||||
relayMode = RelayModeRealtime
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
@@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -92,11 +92,6 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
@@ -127,20 +122,19 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
var httpResp *http.Response
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
@@ -148,7 +142,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
||||
postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -149,24 +149,22 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
var httpResp *http.Response
|
||||
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr := service.RelayErrorHandler(httpResp)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
_, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
_, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
|
||||
@@ -31,7 +31,7 @@ func RelayMidjourneyImage(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(midjourneyTask.ImageUrl)
|
||||
resp, err := common.ProxiedHttpGet(midjourneyTask.ImageUrl, common.OutProxyUrl)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "http_get_image_failed",
|
||||
@@ -112,7 +112,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
@@ -131,11 +131,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
includeUsage := false
|
||||
// 判断用户是否需要返回使用情况
|
||||
if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
|
||||
@@ -184,30 +180,30 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -219,6 +215,15 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
prompts := textRequest.Prompt
|
||||
switch v := prompts.(type) {
|
||||
case string:
|
||||
prompts = v + textRequest.Suffix
|
||||
case []string:
|
||||
prompts = append(v, textRequest.Suffix)
|
||||
}
|
||||
|
||||
promptTokens, err = service.CountTokenInput(prompts, textRequest.Model)
|
||||
case relayconstant.RelayModeModerations:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
@@ -367,9 +372,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
} else if strings.HasPrefix(logModel, "g-") {
|
||||
logModel = "g-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if extraContent != "" {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"one-api/relay/channel/dify"
|
||||
"one-api/relay/channel/gemini"
|
||||
"one-api/relay/channel/jina"
|
||||
"one-api/relay/channel/mistral"
|
||||
"one-api/relay/channel/ollama"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/palm"
|
||||
@@ -69,8 +68,6 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &siliconflow.Adaptor{}
|
||||
case constant.APITypeVertexAi:
|
||||
return &vertex.Adaptor{}
|
||||
case constant.APITypeMistral:
|
||||
return &mistral.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
return token
|
||||
}
|
||||
|
||||
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
var rerankRequest *dto.RerankRequest
|
||||
@@ -79,12 +79,6 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
@@ -105,24 +99,23 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
openaiErr := service.RelayErrorHandler(resp)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||
postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
|
||||
// _, p, err := ws.ReadMessage()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// realtimeEvent := &dto.RealtimeEvent{}
|
||||
// err = json.Unmarshal(p, realtimeEvent)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// // save the original request
|
||||
// if realtimeEvent.Session == nil {
|
||||
// return nil, errors.New("session object is nil")
|
||||
// }
|
||||
// c.Set("first_wss_request", p)
|
||||
// return realtimeEvent, nil
|
||||
//}
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
//realtimeEvent, err := getAndValidateWssRequest(c, ws)
|
||||
//if err != nil {
|
||||
// common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
|
||||
// return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
//}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
//isModelMapped := false
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[relayInfo.OriginModelName] != "" {
|
||||
relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
|
||||
// set upstream model name
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
//if constant.ShouldCheckPromptSensitive() {
|
||||
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||
// if err != nil {
|
||||
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
// }
|
||||
//}
|
||||
|
||||
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||
//// count messages token error 计算promptTokens错误
|
||||
//if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
//}
|
||||
//
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||
//}
|
||||
modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
relayInfo.UsePrice = true
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||
}
|
||||
}()
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
//var requestBody io.Reader
|
||||
//firstWssRequest, _ := c.Get("first_wss_request")
|
||||
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, nil)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
relayInfo.TargetWs = resp.(*websocket.Conn)
|
||||
defer relayInfo.TargetWs.Close()
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
|
||||
if openaiErr != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
|
||||
// var promptTokens int
|
||||
// var err error
|
||||
// switch info.RelayMode {
|
||||
// default:
|
||||
// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
|
||||
// }
|
||||
// info.PromptTokens = promptTokens
|
||||
// return promptTokens, err
|
||||
//}
|
||||
|
||||
//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
|
||||
// var err error
|
||||
// switch info.RelayMode {
|
||||
// case relayconstant.RelayModeChatCompletions:
|
||||
// err = service.CheckSensitiveMessages(textRequest.Messages)
|
||||
// case relayconstant.RelayModeCompletions:
|
||||
// err = service.CheckSensitiveInput(textRequest.Prompt)
|
||||
// case relayconstant.RelayModeModerations:
|
||||
// err = service.CheckSensitiveInput(textRequest.Input)
|
||||
// case relayconstant.RelayModeEmbeddings:
|
||||
// err = service.CheckSensitiveInput(textRequest.Input)
|
||||
// }
|
||||
// return err
|
||||
//}
|
||||
@@ -18,13 +18,14 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
||||
apiRouter.GET("/notice", controller.GetNotice)
|
||||
apiRouter.GET("/about", controller.GetAbout)
|
||||
//apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/midjourney", controller.GetMidjourney)
|
||||
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
|
||||
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
|
||||
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxDoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
||||
@@ -32,13 +33,14 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
|
||||
|
||||
apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
|
||||
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
|
||||
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
|
||||
userRoute.GET("/logout", controller.Logout)
|
||||
userRoute.GET("/epay/notify", controller.EpayNotify)
|
||||
userRoute.GET("/groups", controller.GetUserGroups)
|
||||
|
||||
selfRoute := userRoute.Group("/")
|
||||
@@ -51,8 +53,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||
selfRoute.GET("/aff", controller.GetAffCode)
|
||||
selfRoute.POST("/topup", controller.TopUp)
|
||||
selfRoute.POST("/pay", controller.RequestEpay)
|
||||
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
||||
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestPayLink)
|
||||
selfRoute.POST("/amount", controller.RequestAmount)
|
||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||
}
|
||||
@@ -66,7 +68,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
adminRoute.POST("/", controller.CreateUser)
|
||||
adminRoute.POST("/manage", controller.ManageUser)
|
||||
adminRoute.PUT("/", controller.UpdateUser)
|
||||
adminRoute.DELETE("/:id", controller.DeleteUser)
|
||||
adminRoute.DELETE("/:id", controller.HardDeleteUser)
|
||||
}
|
||||
}
|
||||
optionRoute := apiRouter.Group("/option")
|
||||
|
||||
@@ -22,41 +22,32 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
playgroundRouter.POST("/chat/completions", controller.Playground)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.TokenAuth())
|
||||
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
// WebSocket 路由
|
||||
wsRouter := relayV1Router.Group("")
|
||||
wsRouter.Use(middleware.Distribute())
|
||||
wsRouter.GET("/realtime", controller.WssRelay)
|
||||
}
|
||||
{
|
||||
//http router
|
||||
httpRouter := relayV1Router.Group("")
|
||||
httpRouter.Use(middleware.Distribute())
|
||||
httpRouter.POST("/completions", controller.Relay)
|
||||
httpRouter.POST("/chat/completions", controller.Relay)
|
||||
httpRouter.POST("/edits", controller.Relay)
|
||||
httpRouter.POST("/images/generations", controller.Relay)
|
||||
httpRouter.POST("/images/edits", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/embeddings", controller.Relay)
|
||||
httpRouter.POST("/engines/:model/embeddings", controller.Relay)
|
||||
httpRouter.POST("/audio/transcriptions", controller.Relay)
|
||||
httpRouter.POST("/audio/translations", controller.Relay)
|
||||
httpRouter.POST("/audio/speech", controller.Relay)
|
||||
httpRouter.GET("/files", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/files", controller.RelayNotImplemented)
|
||||
httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/files/:id", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
|
||||
httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
|
||||
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/moderations", controller.Relay)
|
||||
httpRouter.POST("/rerank", controller.Relay)
|
||||
relayV1Router.POST("/completions", controller.Relay)
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
relayV1Router.POST("/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/generations", controller.Relay)
|
||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||
relayV1Router.POST("/audio/translations", controller.Relay)
|
||||
relayV1Router.POST("/audio/speech", controller.Relay)
|
||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/moderations", controller.Relay)
|
||||
relayV1Router.POST("/rerank", controller.Relay)
|
||||
}
|
||||
|
||||
relayMjRouter := router.Group("/mj")
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
|
||||
audioData, err := base64.StdEncoding.DecodeString(audioBase64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("base64 decode error: %v", err)
|
||||
}
|
||||
|
||||
var samplesCount int
|
||||
var sampleRate int
|
||||
|
||||
switch format {
|
||||
case "pcm16":
|
||||
samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
|
||||
sampleRate = 24000 // 24kHz
|
||||
case "g711_ulaw", "g711_alaw":
|
||||
samplesCount = len(audioData) // 8位 = 1字节每样本
|
||||
sampleRate = 8000 // 8kHz
|
||||
default:
|
||||
samplesCount = len(audioData) // 8位 = 1字节每样本
|
||||
sampleRate = 8000 // 8kHz
|
||||
}
|
||||
|
||||
duration = float64(samplesCount) / float64(sampleRate)
|
||||
return duration, nil
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/constant"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if constant.CustomCallbackAddress == "" {
|
||||
return constant.ServerAddress
|
||||
}
|
||||
return constant.CustomCallbackAddress
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
)
|
||||
|
||||
@@ -18,13 +17,3 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
other["admin_info"] = adminInfo
|
||||
return other
|
||||
}
|
||||
|
||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
info["ws"] = true
|
||||
info["audio_input"] = usage.InputTokenDetails.AudioTokens
|
||||
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
|
||||
info["text_input"] = usage.InputTokenDetails.TextTokens
|
||||
info["text_output"] = usage.OutputTokenDetails.TextTokens
|
||||
return info
|
||||
}
|
||||
|
||||
140
service/quota.go
140
service/quota.go
@@ -1,140 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"math"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
|
||||
if relayInfo.UsePrice {
|
||||
return nil
|
||||
}
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelName := relayInfo.UpstreamModelName
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
modelRatio := common.GetModelRatio(modelName)
|
||||
|
||||
ratio := groupRatio * modelRatio
|
||||
|
||||
quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
|
||||
quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
|
||||
|
||||
quota = int(math.Round(float64(quota) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
|
||||
if userQuota < quota {
|
||||
return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||
}
|
||||
|
||||
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64,
|
||||
groupRatio float64,
|
||||
modelPrice float64, usePrice bool, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||
|
||||
quota := 0
|
||||
if !usePrice {
|
||||
quota = int(math.Round(float64(textInputTokens)*ratio + float64(textOutTokens)*ratio*completionRatio))
|
||||
quota += int(math.Round(float64(audioInputTokens)*ratio*audioRatio + float64(audioOutTokens)*ratio*audioRatio*audioCompletionRatio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
totalTokens := usage.TotalTokens
|
||||
var logContent string
|
||||
if !usePrice {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
|
||||
}
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
} else {
|
||||
//if sensitiveResp != nil {
|
||||
// logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
|
||||
//}
|
||||
//quotaDelta := quota - preConsumedQuota
|
||||
//if quotaDelta != 0 {
|
||||
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
// if err != nil {
|
||||
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
// }
|
||||
//}
|
||||
|
||||
//err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
//if err != nil {
|
||||
// common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
//}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -43,47 +42,11 @@ func Done(c *gin.Context) {
|
||||
_ = StringData(c, "[DONE]")
|
||||
}
|
||||
|
||||
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
|
||||
if ws == nil {
|
||||
common.LogError(c, "websocket connection is nil")
|
||||
return errors.New("websocket connection is nil")
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
|
||||
return ws.WriteMessage(1, []byte(str))
|
||||
}
|
||||
|
||||
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
|
||||
jsonData, err := json.Marshal(object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshalling object: %w", err)
|
||||
}
|
||||
if ws == nil {
|
||||
common.LogError(c, "websocket connection is nil")
|
||||
return errors.New("websocket connection is nil")
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
|
||||
return ws.WriteMessage(1, jsonData)
|
||||
}
|
||||
|
||||
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
|
||||
errorObj := &dto.RealtimeEvent{
|
||||
Type: "error",
|
||||
EventId: GetLocalRealtimeID(c),
|
||||
Error: &openaiError,
|
||||
}
|
||||
_ = WssObject(c, ws, errorObj)
|
||||
}
|
||||
|
||||
func GetResponseID(c *gin.Context) string {
|
||||
logID := c.GetString(common.RequestIdKey)
|
||||
logID := c.GetString("X-Oneapi-Request-Id")
|
||||
return fmt.Sprintf("chatcmpl-%s", logID)
|
||||
}
|
||||
|
||||
func GetLocalRealtimeID(c *gin.Context) string {
|
||||
logID := c.GetString(common.RequestIdKey)
|
||||
return fmt.Sprintf("evt_%s", logID)
|
||||
}
|
||||
|
||||
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
|
||||
return &dto.ChatCompletionsStreamResponse{
|
||||
Id: id,
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
var cl200kTokenEncoder *tiktoken.Tiktoken
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
@@ -32,19 +30,22 @@ func InitTokenEncoders() {
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
|
||||
|
||||
gpt4oTokenEncoder, err := tiktoken.EncodingForModel("gpt-4o")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
||||
}
|
||||
for model, _ := range common.GetDefaultModelRatioMap() {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4o") {
|
||||
tokenEncoderMap[model] = gpt4oTokenEncoder
|
||||
} else if strings.HasPrefix(model, "chatgpt-4o") {
|
||||
tokenEncoderMap[model] = gpt4oTokenEncoder
|
||||
} else if "o1" == model || strings.HasPrefix(model, "o1") {
|
||||
tokenEncoderMap[model] = gpt4oTokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
if strings.HasPrefix(model, "gpt-4o") {
|
||||
tokenEncoderMap[model] = cl200kTokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
}
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = nil
|
||||
}
|
||||
@@ -52,13 +53,6 @@ func InitTokenEncoders() {
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
|
||||
return cl200kTokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, ok := tokenEncoderMap[model]
|
||||
if ok && tokenEncoder != nil {
|
||||
@@ -69,13 +63,12 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||
tokenEncoder = getModelDefaultTokenEncoder(model)
|
||||
tokenEncoder = defaultTokenEncoder
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
// 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
|
||||
return getModelDefaultTokenEncoder(model)
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
@@ -112,10 +105,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
var err error
|
||||
var format string
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
config, format, err = DecodeUrlImageData(imageUrl.Url)
|
||||
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
|
||||
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, _, err = DecodeBase64ImageData(imageUrl.Url)
|
||||
config, format, _, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -192,72 +186,6 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
|
||||
audioToken := 0
|
||||
textToken := 0
|
||||
switch request.Type {
|
||||
case dto.RealtimeEventTypeSessionUpdate:
|
||||
if request.Session != nil {
|
||||
msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
textToken += msgTokens
|
||||
}
|
||||
case dto.RealtimeEventResponseAudioDelta:
|
||||
// count audio token
|
||||
atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||||
}
|
||||
audioToken += atk
|
||||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||||
// count text token
|
||||
tkm, err := CountTextToken(request.Delta, model)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
||||
}
|
||||
textToken += tkm
|
||||
case dto.RealtimeEventInputAudioBufferAppend:
|
||||
// count audio token
|
||||
atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting audio token: %v", err)
|
||||
}
|
||||
audioToken += atk
|
||||
case dto.RealtimeEventConversationItemCreated:
|
||||
if request.Item != nil {
|
||||
switch request.Item.Type {
|
||||
case "message":
|
||||
for _, content := range request.Item.Content {
|
||||
if content.Type == "input_text" {
|
||||
tokens, err := CountTextToken(content.Text, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
textToken += tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case dto.RealtimeEventTypeResponseDone:
|
||||
// count tools token
|
||||
if !info.IsFirstRequest {
|
||||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||||
for _, tool := range info.RealtimeTools {
|
||||
toolTokens, err := CountTokenInput(tool, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
textToken += 8
|
||||
textToken += toolTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return textToken, audioToken, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
@@ -290,7 +218,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
} else {
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == dto.ContentTypeImageURL {
|
||||
if m.Type == "image_url" {
|
||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
||||
if err != nil {
|
||||
@@ -298,9 +226,6 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
}
|
||||
tokenNum += imageTokenNum
|
||||
log.Printf("image token num: %d", imageTokenNum)
|
||||
} else if m.Type == dto.ContentTypeInputAudio {
|
||||
// TODO: 音频token数量计算
|
||||
tokenNum += 100
|
||||
} else {
|
||||
tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||
}
|
||||
@@ -315,13 +240,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
func CountTokenInput(input any, model string) (int, error) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTextToken(v, model)
|
||||
return CountTokenText(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return CountTextToken(text, model)
|
||||
return CountTokenText(text, model)
|
||||
}
|
||||
return CountTokenInput(fmt.Sprintf("%v", input), model)
|
||||
}
|
||||
@@ -343,44 +268,16 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTTSToken(text string, model string) (int, error) {
|
||||
func CountAudioToken(text string, model string) (int, error) {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text), nil
|
||||
} else {
|
||||
return CountTextToken(text, model)
|
||||
return CountTokenText(text, model)
|
||||
}
|
||||
}
|
||||
|
||||
func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
|
||||
if audioBase64 == "" {
|
||||
return 0, nil
|
||||
}
|
||||
duration, err := parseAudio(audioBase64, audioFormat)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(duration / 60 * 100 / 0.06), nil
|
||||
}
|
||||
|
||||
func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
|
||||
if audioBase64 == "" {
|
||||
return 0, nil
|
||||
}
|
||||
duration, err := parseAudio(audioBase64, audioFormat)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(duration / 60 * 200 / 0.24), nil
|
||||
}
|
||||
|
||||
//func CountAudioToken(sec float64, audioType string) {
|
||||
// if audioType == "input" {
|
||||
//
|
||||
// }
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTextToken(text string, model string) (int, error) {
|
||||
// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTokenText(text string, model string) (int, error) {
|
||||
var err error
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text), err
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm, err := CountTextToken(responseText, modeName)
|
||||
ctkm, err := CountTokenText(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage, err
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
|
||||
if constant.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
|
||||
workerUrl := constant.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
// post request to worker
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
|
||||
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
5
web/.gitignore
vendored
5
web/.gitignore
vendored
@@ -10,6 +10,7 @@
|
||||
|
||||
# production
|
||||
/build
|
||||
/dist
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
@@ -21,6 +22,4 @@
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.idea
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
.idea/
|
||||
@@ -1 +1 @@
|
||||
module.exports = require("@so1ve/prettier-config");
|
||||
module.exports = require('@so1ve/prettier-config');
|
||||
|
||||
@@ -55,7 +55,7 @@
|
||||
"@vitejs/plugin-react": "^4.2.1",
|
||||
"prettier": "^3.0.0",
|
||||
"typescript": "4.4.2",
|
||||
"vite": "^5.2.0"
|
||||
"vite": "^5.4.14"
|
||||
},
|
||||
"prettier": {
|
||||
"singleQuote": true,
|
||||
|
||||
5489
web/pnpm-lock.yaml
generated
5489
web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user