From d88e07fd9a6a26d241076ccd295846d32da2d346 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Fri, 31 Jan 2025 15:15:59 +0800 Subject: [PATCH 1/4] feat: add deepseek-reasoner & gemini-2.0-flash-thinking-exp-01-21 (#2045) * feat: add MILLI_USD constant and update pricing for deepseek services * feat: add support for new Gemini model version 'gemini-2.0-flash-thinking-exp-01-21' --- relay/adaptor/gemini/adaptor.go | 8 ++++-- relay/adaptor/gemini/constants.go | 2 +- relay/adaptor/vertexai/gemini/adapter.go | 3 +- relay/billing/ratio/model.go | 35 ++++++++++++++---------- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index a86fde40..edae1791 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -7,7 +7,6 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" channelhelper "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -24,8 +23,11 @@ func (a *Adaptor) Init(meta *meta.Meta) { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - defaultVersion := config.GeminiVersion - if meta.ActualModelName == "gemini-2.0-flash-exp" { + var defaultVersion string + switch meta.ActualModelName { + case "gemini-2.0-flash-exp", + "gemini-2.0-flash-thinking-exp", + "gemini-2.0-flash-thinking-exp-01-21": defaultVersion = "v1beta" } diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go index 9d1cbc4a..381d0c12 100644 --- a/relay/adaptor/gemini/constants.go +++ b/relay/adaptor/gemini/constants.go @@ -7,5 +7,5 @@ var ModelList = []string{ "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa", "gemini-2.0-flash-exp", - "gemini-2.0-flash-thinking-exp", + "gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", } diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index b5377875..1240ea5b 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -18,7 +18,8 @@ var ModelList = []string{ "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-1.5-pro-002", "gemini-1.5-flash-002", - "gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", + "gemini-2.0-flash-exp", + "gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", } type Adaptor struct { diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index f83aa70c..7fe08506 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -9,9 +9,10 @@ import ( ) const ( - USD2RMB = 7 - USD = 500 // $0.002 = 1 -> $1 = 500 - RMB = USD / USD2RMB + USD2RMB = 7 + USD = 500 // $0.002 = 1 -> $1 = 500 + MILLI_USD = 1.0 / 1000 * USD + RMB = USD / USD2RMB ) // ModelRatio @@ -109,15 +110,16 @@ var ModelRatio = map[string]float64{ "bge-large-en": 0.002 * RMB, "tao-8k": 0.002 * RMB, // https://ai.google.dev/pricing - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-1.0-pro": 1, - "gemini-1.5-pro": 1, - "gemini-1.5-pro-001": 1, - "gemini-1.5-flash": 1, - "gemini-1.5-flash-001": 1, - "gemini-2.0-flash-exp": 1, - "gemini-2.0-flash-thinking-exp": 1, - "aqa": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro": 1, + "gemini-1.5-pro": 1, + "gemini-1.5-pro-001": 1, + "gemini-1.5-flash": 1, + "gemini-1.5-flash-001": 1, + 
"gemini-2.0-flash-exp": 1, + "gemini-2.0-flash-thinking-exp": 1, + "gemini-2.0-flash-thinking-exp-01-21": 1, + "aqa": 1, // https://open.bigmodel.cn/pricing "glm-4": 0.1 * RMB, "glm-4v": 0.1 * RMB, @@ -279,8 +281,8 @@ var ModelRatio = map[string]float64{ "command-r": 0.5 / 1000 * USD, "command-r-plus": 3.0 / 1000 * USD, // https://platform.deepseek.com/api-docs/pricing/ - "deepseek-chat": 1.0 / 1000 * RMB, - "deepseek-coder": 1.0 / 1000 * RMB, + "deepseek-chat": 0.14 * MILLI_USD, + "deepseek-reasoner": 0.55 * MILLI_USD, // https://www.deepl.com/pro?cta=header-prices "deepl-zh": 25.0 / 1000 * USD, "deepl-en": 25.0 / 1000 * USD, @@ -337,6 +339,11 @@ var CompletionRatio = map[string]float64{ // aws llama3 "llama3-8b-8192(33)": 0.0006 / 0.0003, "llama3-70b-8192(33)": 0.0035 / 0.00265, + // whisper + "whisper-1": 0, // only count input tokens + // deepseek + "deepseek-chat": 0.28 / 0.14, + "deepseek-reasoner": 2.19 / 0.55, } var ( From 605bb06667014412e72e6cffd415ed8695a2e736 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 31 Jan 2025 16:00:53 +0800 Subject: [PATCH 2/4] feat: update logger --- common/logger/logger.go | 83 ++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 25 deletions(-) diff --git a/common/logger/logger.go b/common/logger/logger.go index d1022932..c5797217 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -7,19 +7,25 @@ import ( "log" "os" "path/filepath" + "runtime" + "strings" "sync" "time" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" ) +type loggerLevel string + const ( - loggerDEBUG = "DEBUG" - loggerINFO = "INFO" - loggerWarn = "WARN" - loggerError = "ERR" + loggerDEBUG loggerLevel = "DEBUG" + loggerINFO loggerLevel = "INFO" + loggerWarn loggerLevel = "WARN" + loggerError loggerLevel = "ERROR" + loggerFatal loggerLevel = "FATAL" ) var setupLogOnce sync.Once @@ -44,27 +50,26 @@ func SetupLogger() { } func SysLog(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) + logHelper(nil, loggerINFO, s) } func SysLogf(format string, a ...any) { - SysLog(fmt.Sprintf(format, a...)) + logHelper(nil, loggerINFO, fmt.Sprintf(format, a...)) } func SysError(s string) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) + logHelper(nil, loggerError, s) } func SysErrorf(format string, a ...any) { - SysError(fmt.Sprintf(format, a...)) + logHelper(nil, loggerError, fmt.Sprintf(format, a...)) } func Debug(ctx context.Context, msg string) { - if config.DebugEnabled { - logHelper(ctx, loggerDEBUG, msg) + if !config.DebugEnabled { + return } + logHelper(ctx, loggerDEBUG, msg) } func Info(ctx context.Context, msg string) { @@ -80,37 +85,65 @@ func Error(ctx context.Context, msg string) { } func Debugf(ctx context.Context, format string, a ...any) { - Debug(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...)) } func Infof(ctx context.Context, format string, a ...any) { - Info(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...)) } func Warnf(ctx context.Context, format string, a ...any) { - Warn(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...)) } func Errorf(ctx context.Context, format string, a ...any) { - Error(ctx, fmt.Sprintf(format, a...)) + logHelper(ctx, loggerError, fmt.Sprintf(format, a...)) } -func logHelper(ctx 
context.Context, level string, msg string) { +func FatalLog(s string) { + logHelper(nil, loggerFatal, s) +} + +func FatalLogf(format string, a ...any) { + logHelper(nil, loggerFatal, fmt.Sprintf(format, a...)) +} + +func logHelper(ctx context.Context, level loggerLevel, msg string) { writer := gin.DefaultErrorWriter if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(helper.RequestIdKey) - if id == nil { - id = helper.GenRequestID() + var logId string + if ctx != nil { + rawLogId := ctx.Value(helper.RequestIdKey) + if rawLogId != nil { + logId = fmt.Sprintf(" | %s", rawLogId.(string)) + } } + lineInfo, funcName := getLineInfo() now := time.Now() - _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) + _, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), logId, lineInfo, funcName, msg) SetupLogger() + if level == loggerFatal { + os.Exit(1) + } } -func FatalLog(v ...any) { - t := time.Now() - _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) - os.Exit(1) +func getLineInfo() (string, string) { + funcName := "[unknown] " + pc, file, line, ok := runtime.Caller(3) + if ok { + if fn := runtime.FuncForPC(pc); fn != nil { + parts := strings.Split(fn.Name(), ".") + funcName = "[" + parts[len(parts)-1] + "] " + } + } else { + file = "unknown" + line = 0 + } + parts := strings.Split(file, "one-api/") + if len(parts) > 1 { + file = parts[1] + } + return fmt.Sprintf(" | %s:%d", file, line), funcName } From f95e6b78b837f6b403cbbc7419658c721ccaac04 Mon Sep 17 00:00:00 2001 From: chenzikun Date: Fri, 31 Jan 2025 16:12:59 +0800 Subject: [PATCH 3/4] fix: fix berry copy token (#2041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [bugfix]修复copy问题 * [update]两阶段编译代码 --------- Co-authored-by: zicorn --- web/berry/src/utils/common.js | 349 ++++++++++++++++++---------------- 1 file changed, 181 insertions(+), 168 deletions(-) diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index bd85f8bf..be296122 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -1,247 +1,260 @@ -import { enqueueSnackbar } from 'notistack'; -import { snackbarConstants } from 'constants/SnackbarConstants'; -import { API } from './api'; +import {enqueueSnackbar} from 'notistack'; +import {snackbarConstants} from 'constants/SnackbarConstants'; +import {API} from './api'; export function getSystemName() { - let system_name = localStorage.getItem('system_name'); - if (!system_name) return 'One API'; - return system_name; + let system_name = localStorage.getItem('system_name'); + if (!system_name) return 'One API'; + return system_name; } export function isMobile() { - return window.innerWidth <= 600; + return window.innerWidth <= 600; } // eslint-disable-next-line -export function SnackbarHTMLContent({ htmlContent }) { - return
<div dangerouslySetInnerHTML={{ __html: htmlContent }} />; +export function SnackbarHTMLContent({htmlContent}) { + return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />
; } export function getSnackbarOptions(variant) { - let options = snackbarConstants.Common[variant]; - if (isMobile()) { - // 合并 options 和 snackbarConstants.Mobile - options = { ...options, ...snackbarConstants.Mobile }; - } - return options; + let options = snackbarConstants.Common[variant]; + if (isMobile()) { + // 合并 options 和 snackbarConstants.Mobile + options = {...options, ...snackbarConstants.Mobile}; + } + return options; } export function showError(error) { - if (error.message) { - if (error.name === 'AxiosError') { - switch (error.response.status) { - case 429: - enqueueSnackbar('错误:请求次数过多,请稍后再试!', getSnackbarOptions('ERROR')); - break; - case 500: - enqueueSnackbar('错误:服务器内部错误,请联系管理员!', getSnackbarOptions('ERROR')); - break; - case 405: - enqueueSnackbar('本站仅作演示之用,无服务端!', getSnackbarOptions('INFO')); - break; - default: - enqueueSnackbar('错误:' + error.message, getSnackbarOptions('ERROR')); - } - return; + if (error.message) { + if (error.name === 'AxiosError') { + switch (error.response.status) { + case 429: + enqueueSnackbar('错误:请求次数过多,请稍后再试!', getSnackbarOptions('ERROR')); + break; + case 500: + enqueueSnackbar('错误:服务器内部错误,请联系管理员!', getSnackbarOptions('ERROR')); + break; + case 405: + enqueueSnackbar('本站仅作演示之用,无服务端!', getSnackbarOptions('INFO')); + break; + default: + enqueueSnackbar('错误:' + error.message, getSnackbarOptions('ERROR')); + } + return; + } + } else { + enqueueSnackbar('错误:' + error, getSnackbarOptions('ERROR')); } - } else { - enqueueSnackbar('错误:' + error, getSnackbarOptions('ERROR')); - } } export function showNotice(message, isHTML = false) { - if (isHTML) { - enqueueSnackbar(, getSnackbarOptions('NOTICE')); - } else { - enqueueSnackbar(message, getSnackbarOptions('NOTICE')); - } + if (isHTML) { + enqueueSnackbar(, getSnackbarOptions('NOTICE')); + } else { + enqueueSnackbar(message, getSnackbarOptions('NOTICE')); + } } export function showWarning(message) { - enqueueSnackbar(message, getSnackbarOptions('WARNING')); + enqueueSnackbar(message, getSnackbarOptions('WARNING')); } export function showSuccess(message) { - enqueueSnackbar(message, getSnackbarOptions('SUCCESS')); + enqueueSnackbar(message, getSnackbarOptions('SUCCESS')); } export function showInfo(message) { - enqueueSnackbar(message, getSnackbarOptions('INFO')); + enqueueSnackbar(message, getSnackbarOptions('INFO')); } export async function getOAuthState() { - const res = await API.get('/api/oauth/state'); - const { success, message, data } = res.data; - if (success) { - return data; - } else { - showError(message); - return ''; - } + const res = await API.get('/api/oauth/state'); + const {success, message, data} = res.data; + if (success) { + return data; + } else { + showError(message); + return ''; + } } export async function onGitHubOAuthClicked(github_client_id, openInNewTab = false) { - const state = await getOAuthState(); - if (!state) return; - let url = `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`; - if (openInNewTab) { - window.open(url); - } else { - window.location.href = url; - } + const state = await getOAuthState(); + if (!state) return; + let url = `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`; + if (openInNewTab) { + window.open(url); + } else { + window.location.href = url; + } } export async function onLarkOAuthClicked(lark_client_id) { - const state = await getOAuthState(); - if (!state) return; - let redirect_uri = `${window.location.origin}/oauth/lark`; - 
window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`); + const state = await getOAuthState(); + if (!state) return; + let redirect_uri = `${window.location.origin}/oauth/lark`; + window.open(`https://accounts.feishu.cn/open-apis/authen/v1/authorize?redirect_uri=${redirect_uri}&client_id=${lark_client_id}&state=${state}`); } export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { - const state = await getOAuthState(); - if (!state) return; - const redirect_uri = `${window.location.origin}/oauth/oidc`; - const response_type = "code"; - const scope = "openid profile email"; - const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`; - if (openInNewTab) { - window.open(url); - } else - { - window.location.href = url; - } + const state = await getOAuthState(); + if (!state) return; + const redirect_uri = `${window.location.origin}/oauth/oidc`; + const response_type = "code"; + const scope = "openid profile email"; + const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`; + if (openInNewTab) { + window.open(url); + } else { + window.location.href = url; + } } export function isAdmin() { - let user = localStorage.getItem('user'); - if (!user) return false; - user = JSON.parse(user); - return user.role >= 10; + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 10; } export function timestamp2string(timestamp) { - let date = new Date(timestamp * 1000); - let year = date.getFullYear().toString(); - let month = (date.getMonth() + 1).toString(); - let day = date.getDate().toString(); - let hour = date.getHours().toString(); - let minute = date.getMinutes().toString(); - let second = date.getSeconds().toString(); - if (month.length === 1) { - month = '0' + month; - } - if (day.length === 1) { - day = '0' + day; - } - if (hour.length === 1) { - hour = '0' + hour; - } - if (minute.length === 1) { - minute = '0' + minute; - } - if (second.length === 1) { - second = '0' + second; - } - return year + '-' + month + '-' + day + ' ' + hour + ':' + minute + ':' + second; + let date = new Date(timestamp * 1000); + let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + let minute = date.getMinutes().toString(); + let second = date.getSeconds().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if (hour.length === 1) { + hour = '0' + hour; + } + if (minute.length === 1) { + minute = '0' + minute; + } + if (second.length === 1) { + second = '0' + second; + } + return year + '-' + month + '-' + day + ' ' + hour + ':' + minute + ':' + second; } export function calculateQuota(quota, digits = 2) { - let quotaPerUnit = localStorage.getItem('quota_per_unit'); - quotaPerUnit = parseFloat(quotaPerUnit); + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + quotaPerUnit = parseFloat(quotaPerUnit); - return (quota / quotaPerUnit).toFixed(digits); + return (quota / quotaPerUnit).toFixed(digits); } export function renderQuota(quota, digits = 2) { - let displayInCurrency = localStorage.getItem('display_in_currency'); - displayInCurrency = displayInCurrency === 'true'; - if (displayInCurrency) { - 
return '$' + calculateQuota(quota, digits); - } - return renderNumber(quota); + let displayInCurrency = localStorage.getItem('display_in_currency'); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return '$' + calculateQuota(quota, digits); + } + return renderNumber(quota); } export const verifyJSON = (str) => { - try { - JSON.parse(str); - } catch (e) { - return false; - } - return true; + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; }; export function renderNumber(num) { - if (num >= 1000000000) { - return (num / 1000000000).toFixed(1) + 'B'; - } else if (num >= 1000000) { - return (num / 1000000).toFixed(1) + 'M'; - } else if (num >= 10000) { - return (num / 1000).toFixed(1) + 'k'; - } else { - return num; - } + if (num >= 1000000000) { + return (num / 1000000000).toFixed(1) + 'B'; + } else if (num >= 1000000) { + return (num / 1000000).toFixed(1) + 'M'; + } else if (num >= 10000) { + return (num / 1000).toFixed(1) + 'k'; + } else { + return num; + } } export function renderQuotaWithPrompt(quota, digits) { - let displayInCurrency = localStorage.getItem('display_in_currency'); - displayInCurrency = displayInCurrency === 'true'; - if (displayInCurrency) { - return `(等价金额:${renderQuota(quota, digits)})`; - } - return ''; + let displayInCurrency = localStorage.getItem('display_in_currency'); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return `(等价金额:${renderQuota(quota, digits)})`; + } + return ''; } export function downloadTextAsFile(text, filename) { - let blob = new Blob([text], { type: 'text/plain;charset=utf-8' }); - let url = URL.createObjectURL(blob); - let a = document.createElement('a'); - a.href = url; - a.download = filename; - a.click(); + let blob = new Blob([text], {type: 'text/plain;charset=utf-8'}); + let url = URL.createObjectURL(blob); + let a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); } export function removeTrailingSlash(url) { - if (url.endsWith('/')) { - return url.slice(0, -1); - } else { - return url; - } + if (url.endsWith('/')) { + return url.slice(0, -1); + } else { + return url; + } } let channelModels = undefined; + export async function loadChannelModels() { - const res = await API.get('/api/models'); - const { success, data } = res.data; - if (!success) { - return; - } - channelModels = data; - localStorage.setItem('channel_models', JSON.stringify(data)); + const res = await API.get('/api/models'); + const {success, data} = res.data; + if (!success) { + return; + } + channelModels = data; + localStorage.setItem('channel_models', JSON.stringify(data)); } export function getChannelModels(type) { - if (channelModels !== undefined && type in channelModels) { - return channelModels[type]; - } - let models = localStorage.getItem('channel_models'); - if (!models) { + if (channelModels !== undefined && type in channelModels) { + return channelModels[type]; + } + let models = localStorage.getItem('channel_models'); + if (!models) { + return []; + } + channelModels = JSON.parse(models); + if (type in channelModels) { + return channelModels[type]; + } return []; - } - channelModels = JSON.parse(models); - if (type in channelModels) { - return channelModels[type]; - } - return []; } export function copy(text, name = '') { - try { - navigator.clipboard.writeText(text); - } catch (error) { - text = `复制${name}失败,请手动复制:
<br /><br />
${text}`; - enqueueSnackbar(, getSnackbarOptions('COPY')); - return; - } - showSuccess(`复制${name}成功!`); + if (navigator.clipboard && navigator.clipboard.writeText) { + navigator.clipboard.writeText(text).then(() => { + showNotice(`复制${name}成功!`, true); + }, () => { + text = `复制${name}失败,请手动复制:
<br /><br />
${text}`; + enqueueSnackbar(, getSnackbarOptions('COPY')); + }); + } else { + const textArea = document.createElement("textarea"); + textArea.value = text; + document.body.appendChild(textArea); + textArea.select(); + try { + document.execCommand('copy'); + showNotice(`复制${name}成功!`, true); + } catch (err) { + text = `复制${name}失败,请手动复制:
<br /><br />
${text}`; + enqueueSnackbar(, getSnackbarOptions('COPY')); + } + document.body.removeChild(textArea); + } } From 09911a301df4aac3fc8271cbd3b1eb7539e97c94 Mon Sep 17 00:00:00 2001 From: Fennng Date: Fri, 31 Jan 2025 16:48:02 +0800 Subject: [PATCH 4/4] feat: support hunyuan-embedding (#2035) * feat: support hunyuan-embedding * chore: improve implementation --------- Co-authored-by: LUO Feng Co-authored-by: JustSong --- relay/adaptor/tencent/adaptor.go | 32 ++++++++++---- relay/adaptor/tencent/constants.go | 1 + relay/adaptor/tencent/main.go | 68 ++++++++++++++++++++++++++++-- relay/adaptor/tencent/model.go | 46 +++++++++++++++----- 4 files changed, 126 insertions(+), 21 deletions(-) diff --git a/relay/adaptor/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go index 0de92d4a..b20d4279 100644 --- a/relay/adaptor/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -2,16 +2,19 @@ package tencent import ( "errors" + "io" + "net/http" + "strconv" + "strings" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" + "github.com/songquanpeng/one-api/relay/relaymode" ) // https://cloud.tencent.com/document/api/1729/101837 @@ -52,10 +55,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if err != nil { return nil, err } - tencentRequest := ConvertRequest(*request) + var convertedRequest any + switch relayMode { + case relaymode.Embeddings: + a.Action = "GetEmbedding" + convertedRequest = ConvertEmbeddingRequest(*request) + default: + a.Action = "ChatCompletions" + convertedRequest = ConvertRequest(*request) + } // we have to calculate the sign here - a.Sign = GetSign(*tencentRequest, a, secretId, secretKey) - return tencentRequest, nil + a.Sign = GetSign(convertedRequest, a, secretId, secretKey) + return convertedRequest, nil } func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { @@ -75,7 +86,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met err, responseText = StreamHandler(c, resp) usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } else { - err, usage = Handler(c, resp) + switch meta.Mode { + case relaymode.Embeddings: + err, usage = EmbeddingHandler(c, resp) + default: + err, usage = Handler(c, resp) + } } return } diff --git a/relay/adaptor/tencent/constants.go b/relay/adaptor/tencent/constants.go index e8631e5f..7997bfd6 100644 --- a/relay/adaptor/tencent/constants.go +++ b/relay/adaptor/tencent/constants.go @@ -6,4 +6,5 @@ var ModelList = []string{ "hunyuan-standard-256K", "hunyuan-pro", "hunyuan-vision", + "hunyuan-embedding", } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 827c8a46..8bf8e469 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strconv" @@ -16,11 +15,14 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" + 
"github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" @@ -44,8 +46,68 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { } } +func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { + return &EmbeddingRequest{ + InputList: request.ParseInput(), + } +} + +func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var tencentResponseP EmbeddingResponseP + err := json.NewDecoder(resp.Body).Decode(&tencentResponseP) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + tencentResponse := tencentResponseP.Response + if tencentResponse.Error.Code != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: tencentResponse.Error.Message, + Code: tencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + requestModel := c.GetString(ctxkey.RequestModel) + fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse) + fullTextResponse.Model = requestModel + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse { + openAIEmbeddingResponse := openai.EmbeddingResponse{ + Object: "list", + Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)), + Model: "hunyuan-embedding", + Usage: model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens}, + } + + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ + Id: response.ReqID, Object: "chat.completion", Created: helper.GetTimestamp(), Usage: model.Usage{ @@ -148,7 +210,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, * return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } TencentResponse = responseP.Response - if TencentResponse.Error.Code != 0 { + if TencentResponse.Error.Code != "" { return &model.ErrorWithStatusCode{ Error: model.Error{ Message: TencentResponse.Error.Message, @@ -195,7 +257,7 @@ func hmacSha256(s, key string) string { return string(hashed.Sum(nil)) } -func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string { +func GetSign(req any, adaptor *Adaptor, secId, secKey string) string { // build canonical request string host := "hunyuan.tencentcloudapi.com" httpRequestMethod := "POST" diff --git a/relay/adaptor/tencent/model.go b/relay/adaptor/tencent/model.go index fb97724e..fda6c6cc 100644 --- a/relay/adaptor/tencent/model.go +++ b/relay/adaptor/tencent/model.go @@ -35,16 +35,16 @@ type ChatRequest struct { // 
1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 // 3. 非必要不建议使用,不合理的取值会影响效果。 - TopP *float64 `json:"TopP"` + TopP *float64 `json:"TopP,omitempty"` // 说明: // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 // 3. 非必要不建议使用,不合理的取值会影响效果。 - Temperature *float64 `json:"Temperature"` + Temperature *float64 `json:"Temperature,omitempty"` } type Error struct { - Code int `json:"Code"` + Code string `json:"Code"` Message string `json:"Message"` } @@ -61,15 +61,41 @@ type ResponseChoices struct { } type ChatResponse struct { - Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 - Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 - Id string `json:"Id,omitempty"` // 会话 id - Usage Usage `json:"Usage,omitempty"` // token 数量 - Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 - Note string `json:"Note,omitempty"` // 注释 - ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 + Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 + Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 + Id string `json:"Id,omitempty"` // 会话 id + Usage Usage `json:"Usage,omitempty"` // token 数量 + Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"Note,omitempty"` // 注释 + ReqID string `json:"RequestId,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 } type ChatResponseP struct { Response ChatResponse `json:"Response,omitempty"` } + +type EmbeddingRequest struct { + InputList []string `json:"InputList"` +} + +type EmbeddingData struct { + Embedding []float64 `json:"Embedding"` + Index int `json:"Index"` + Object string `json:"Object"` +} + +type EmbeddingUsage struct { + PromptTokens int `json:"PromptTokens"` + TotalTokens int `json:"TotalTokens"` +} + +type EmbeddingResponse struct { + Data []EmbeddingData `json:"Data"` + EmbeddingUsage EmbeddingUsage `json:"Usage,omitempty"` + RequestId string `json:"RequestId,omitempty"` + Error Error `json:"Error,omitempty"` +} + +type EmbeddingResponseP struct { + Response EmbeddingResponse `json:"Response,omitempty"` +}
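
A minimal usage sketch of the new embedding path added in this patch (illustrative only, not part of the diff): it converts an OpenAI-style embedding request into the Tencent GetEmbedding payload via the exported ConvertEmbeddingRequest helper. The concrete Input value below is an assumption; ConvertEmbeddingRequest relies on request.ParseInput(), which normalizes the OpenAI input into a string slice.

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/adaptor/tencent"
	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	// Build an OpenAI-compatible embedding request; the Input field is assumed
	// to accept a []string here, matching what ParseInput() handles upstream.
	req := model.GeneralOpenAIRequest{
		Model: "hunyuan-embedding",
		Input: []string{"hello", "world"},
	}
	// ConvertEmbeddingRequest forwards the parsed inputs into Tencent's
	// InputList field; the adaptor then signs and sends it with Action=GetEmbedding.
	converted := tencent.ConvertEmbeddingRequest(req)
	fmt.Printf("%+v\n", converted) // expected shape: &{InputList:[hello world]}
}

Changing GetSign to accept any (instead of ChatRequest) is what lets the same signing code cover both the chat and embedding payloads above.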