2.5版本 增加dall-e 优化mj 对接mj-plus

This commit is contained in:
小易
2024-02-04 18:51:37 +08:00
parent 822ff0a51d
commit 2ca3c164a0
57 changed files with 12239 additions and 7693 deletions

BIN
service/src/.DS_Store vendored

Binary file not shown.

View File

@@ -2,8 +2,8 @@ import { UploadService } from './../upload/upload.service';
import { UserService } from './../user/user.service';
import { ConfigService } from 'nestjs-config';
import { HttpException, HttpStatus, Injectable, OnModuleInit, Logger } from '@nestjs/common';
import type { ChatGPTAPIOptions, ChatMessage, SendMessageOptions } from 'chatgpt-nine-ai';
import { Request, Response } from 'express';
import type { ChatGPTAPIOptions, ChatMessage, SendMessageOptions } from 'chatgpt-ai-web';
import e, { Request, Response } from 'express';
import { OpenAiErrorCodeMessage } from '@/common/constants/errorMessage.constant';
import {
compileNetwork,
@@ -97,7 +97,7 @@ export class ChatgptService implements OnModuleInit {
};
async onModuleInit() {
let chatgpt = await importDynamic('chatgpt-nine-ai');
let chatgpt = await importDynamic('chatgpt-ai-web');
let KeyvRedis = await importDynamic('@keyv/redis');
let Keyv = await importDynamic('keyv');
chatgpt = chatgpt?.default ? chatgpt.default : chatgpt;
@@ -165,7 +165,7 @@ export class ChatgptService implements OnModuleInit {
/* 不同场景会变更其信息 */
let setSystemMessage = systemMessage;
const { parentMessageId } = options;
const { prompt ,imageUrl,model:activeModel} = body;
const { prompt, imageUrl, model: activeModel } = body;
const { groupId, usingNetwork } = options;
// const { model = 3 } = options;
/* 获取当前对话组的详细配置信息 */
@@ -184,7 +184,7 @@ export class ChatgptService implements OnModuleInit {
throw new HttpException('当前流程所需要的模型已被管理员下架、请联系管理员上架专属模型!', HttpStatus.BAD_REQUEST);
}
const { deduct, isTokenBased, deductType, key: modelKey, secret, modelName, id: keyId, accessToken } = currentRequestModelKey;
const { deduct, isTokenBased, tokenFeeRatio, deductType, key: modelKey, secret, modelName, id: keyId, accessToken } = currentRequestModelKey;
/* 用户状态检测 */
await this.userService.checkUserStatus(req.user);
/* 用户余额检测 */
@@ -260,7 +260,7 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt,
imageUrl,
imageUrl:response?.imageUrl,
activeModel,
answer: '',
promptTokens: prompt_tokens,
@@ -307,7 +307,7 @@ export class ChatgptService implements OnModuleInit {
/* 当用户回答一般停止时 也需要扣费 */
let charge = deduct;
if (isTokenBased === true) {
charge = deduct * total_tokens;
charge = Math.ceil((deduct * total_tokens) / tokenFeeRatio);
}
await this.userBalanceService.deductFromBalance(req.user.id, `model${deductType === 1 ? 3 : 4}`, charge, total_tokens);
});
@@ -320,11 +320,11 @@ export class ChatgptService implements OnModuleInit {
const { context: messagesHistory } = await this.nineStore.buildMessageFromParentMessageId(usingNetwork ? netWorkPrompt : prompt, {
parentMessageId,
systemMessage,
imageUrl,
activeModel,
maxModelToken: maxToken,
maxResponseTokens: maxTokenRes,
maxRounds: addOneIfOdd(rounds),
imageUrl,
activeModel,
});
let firstChunk = true;
response = await sendMessageFromOpenAi(messagesHistory, {
@@ -332,8 +332,9 @@ export class ChatgptService implements OnModuleInit {
maxTokenRes,
apiKey: modelKey,
model,
imageUrl,
prompt,
activeModel,
imageUrl,
temperature,
proxyUrl: proxyResUrl,
onProgress: (chat) => {
@@ -341,7 +342,7 @@ export class ChatgptService implements OnModuleInit {
lastChat = chat;
firstChunk = false;
},
});
},this.uploadService);
isSuccess = true;
}
@@ -385,7 +386,6 @@ export class ChatgptService implements OnModuleInit {
isSuccess = true;
}
/* 分别将本次用户输入的 和 机器人返回的分两次存入到 store */
const userMessageData: MessageInfo = {
id: this.nineStore.getUuid(),
text: prompt,
@@ -407,7 +407,8 @@ export class ChatgptService implements OnModuleInit {
text: response.text,
role: 'assistant',
name: undefined,
usage: response.usage,
usage: response?.usage,
imageUrl: response?.imageUrl,
parentMessageId: userMessageData.id,
conversationId: response?.conversationId,
};
@@ -415,7 +416,6 @@ export class ChatgptService implements OnModuleInit {
await this.nineStore.setData(assistantMessageData);
othersInfo = { model, parentMessageId: userMessageData.id };
/* 回答完毕 */
} else {
const { key, maxToken, maxTokenRes, proxyResUrl } = await this.formatModelToken(currentRequestModelKey);
const { parentMessageId, completionParams, systemMessage } = mergedOptions;
@@ -431,17 +431,22 @@ export class ChatgptService implements OnModuleInit {
temperature,
proxyUrl: proxyResUrl,
onProgress: null,
prompt,
});
}
/* 统一最终输出格式 */
const formatResponse = await unifiedFormattingResponse(keyType, response, othersInfo);
const { prompt_tokens = 0, completion_tokens = 0, total_tokens = 0 } = formatResponse.usage;
let usage = null;
let formatResponse = null;
if (model.includes('dall')) {
usage = response.detail?.usage || { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 };
} else {
formatResponse = await unifiedFormattingResponse(keyType, response, othersInfo);
}
const { prompt_tokens, completion_tokens, total_tokens } = model.includes('dall') ? usage : formatResponse.usage;
/* 区分扣除普通还是高级余额 model3: 普通余额 model4 高级余额 */
let charge = deduct;
if (isTokenBased === true) {
charge = deduct * total_tokens;
charge = Math.ceil((deduct * total_tokens) / tokenFeeRatio);
}
await this.userBalanceService.deductFromBalance(req.user.id, `model${deductType === 1 ? 3 : 4}`, charge, total_tokens);
@@ -457,13 +462,13 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt,
imageUrl,
imageUrl: response?.imageUrl,
activeModel,
answer: '',
promptTokens: prompt_tokens,
completionTokens: 0,
totalTokens: total_tokens,
model: formatResponse.model,
model: model,
role: 'user',
groupId,
requestOptions: JSON.stringify({
@@ -479,7 +484,8 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt: prompt,
answer: formatResponse?.text,
imageUrl: response?.imageUrl,
answer: response.text,
promptTokens: prompt_tokens,
completionTokens: completion_tokens,
totalTokens: total_tokens,
@@ -501,7 +507,7 @@ export class ChatgptService implements OnModuleInit {
}),
});
Logger.debug(
`本次调用: ${req.user.id} model: ${model} key -> ${key}, 模型名称: ${modelName}, 最大回复token: ${maxResponseTokens}`,
`户ID: ${req.user.id} 模型名称: ${modelName}-${activeModel}, 消耗token: ${total_tokens}, 消耗积分: ${charge}`,
'ChatgptService',
);
const userBalance = await this.userBalanceService.queryUserBalance(req.user.id);
@@ -599,7 +605,7 @@ export class ChatgptService implements OnModuleInit {
await this.userBalanceService.validateBalance(req, 'mjDraw', money);
let images = [];
/* 从3的卡池随机拿一个key */
const detailKeyInfo = await this.modelsService.getRandomDrawKey();
const detailKeyInfo = await this.modelsService.getCurrentModelKeyInfo('dall-e-3');
const keyId = detailKeyInfo?.id;
const { key, proxyResUrl } = await this.formatModelToken(detailKeyInfo);
Logger.log(`draw paompt info <==**==> ${body.prompt}, key ===> ${key}`, 'DrawService');

View File

@@ -1,6 +1,9 @@
import axios, { AxiosRequestConfig, AxiosResponse } from 'axios';
import { get_encoding } from '@dqbd/tiktoken'
import { removeSpecialCharacters } from '@/common/utils';
import { ConsoleLogger, HttpException, HttpStatus, Logger } from '@nestjs/common';
import * as uuid from 'uuid';
const tokenizer = get_encoding('cl100k_base')
@@ -11,96 +14,166 @@ interface SendMessageResult {
detail?: any;
}
function getFullUrl(proxyUrl){
function getFullUrl(proxyUrl) {
const processedUrl = proxyUrl.endsWith('/') ? proxyUrl.slice(0, -1) : proxyUrl;
const baseUrl = processedUrl || 'https://api.openai.com'
return `${baseUrl}/v1/chat/completions`
}
export function sendMessageFromOpenAi(messagesHistory, inputs ){
const { onProgress, maxToken, apiKey, model, temperature = 0.95, proxyUrl } = inputs
console.log('current request options: ',apiKey, model, maxToken, proxyUrl );
const max_tokens = compilerToken(model, maxToken)
const options: AxiosRequestConfig = {
method: 'POST',
url: getFullUrl(proxyUrl),
responseType: 'stream',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${removeSpecialCharacters(apiKey)}`,
},
data: {
max_tokens,
stream: true,
temperature,
model,
messages: messagesHistory
},
};
const prompt = messagesHistory[messagesHistory.length-1]?.content
return new Promise(async (resolve, reject) =>{
export async function sendMessageFromOpenAi(messagesHistory, inputs, uploadService?) {
const { onProgress, maxToken, apiKey, model, temperature = 0.8, proxyUrl, prompt } = inputs
if (model.includes('dall')) {
let result: any = { text: '', imageUrl: '' };
try {
const options: AxiosRequestConfig = {
method: 'POST',
url: `${proxyUrl}/v1/images/generations`,
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
data: {
prompt: prompt,
model: model,
response_format: 'b64_json'
},
}
const response: any = await axios(options);
const stream = response.data;
let result: any = { text: '' };
stream.on('data', (chunk) => {
const splitArr = chunk.toString().split('\n\n').filter((line) => line.trim() !== '');
for (const line of splitArr) {
const data = line.replace('data:', '');
let ISEND = false;
try {
ISEND = JSON.parse(data).choices[0].finish_reason === 'stop';
} catch (error) {
ISEND = false;
}
/* 如果结束 返回所有 */
if (data === '[DONE]' || ISEND) {
result.text = result.text.trim();
return result;
}
try {
const parsedData = JSON.parse(data);
if (parsedData.id) {
result.id = parsedData.id;
}
if (parsedData.choices?.length) {
const delta = parsedData.choices[0].delta;
result.delta = delta.content;
if (delta?.content) result.text += delta.content;
if (delta.role) {
result.role = delta.role;
}
result.detail = parsedData;
}
onProgress && onProgress({text:result.text})
} catch (error) {
console.log('parse Error', data )
}
}
});
stream.on('end', () => {
// 手动计算token
if(result.detail && result.text){
const promptTokens = getTokenCount(prompt)
const completionTokens = getTokenCount(result.text)
result.detail.usage = {
prompt_tokens: promptTokens,
completion_tokens: completionTokens ,
total_tokens: promptTokens + completionTokens,
estimated: true
}
}
return resolve(result);
});
const { b64_json, revised_prompt } = response.data.data[0]
const buffer = Buffer.from(b64_json, 'base64');
let imgUrl = '';
try {
const filename = uuid.v4().slice(0, 10) + '.png';
Logger.debug(`------> 开始上传图片!!!`, 'MidjourneyService');
const buffer = Buffer.from(b64_json, 'base64');
// imgUrl = await uploadService.uploadFileFromUrl({ filename, url })
imgUrl = await uploadService.uploadFile({ filename, buffer });
Logger.debug(`图片上传成功URL: ${imgUrl}`, 'MidjourneyService');
} catch (error) {
Logger.error(`上传图片过程中出现错误: ${error}`, 'MidjourneyService');
}
result.imageUrl = imgUrl
result.text = revised_prompt;
onProgress && onProgress({ text: result.text })
return result;
} catch (error) {
reject(error)
const status = error?.response?.status || 500;
console.log('openai-draw error: ', JSON.stringify(error), status);
const message = error?.response?.data?.error?.message;
if (status === 429) {
result.text = '当前请求已过载、请稍等会儿再试试吧!';
return result;
}
if (status === 400 && message.includes('This request has been blocked by our content filters')) {
result.text = '您的请求已被系统拒绝。您的提示可能存在一些非法的文本。';
return result;
}
if (status === 400 && message.includes('Billing hard limit has been reached')) {
result.text = '当前模型key已被封禁、已冻结当前调用Key、尝试重新对话试试吧';
return result;
}
if (status === 500) {
result.text = '绘制图片失败,请检查你的提示词是否有非法描述!';
return result;
}
if (status === 401) {
result.text = '绘制图片失败,此次绘画被拒绝了!';
return result;
}
result.text = '绘制图片失败,请稍后试试吧!';
return result;
}
})
} else {
let result: any = { text: '' };
const options: AxiosRequestConfig = {
method: 'POST',
url: getFullUrl(proxyUrl),
responseType: 'stream',
headers: {
'Content-Type': 'application/json',
Accept: "application/json",
Authorization: `Bearer ${apiKey}`,
},
data: {
stream: true,
temperature,
model,
messages: messagesHistory,
},
};
if (model === 'gpt-4-vision-preview') {
options.data.max_tokens = 2048;
}
return new Promise(async (resolve, reject) => {
try {
const response: any = await axios(options);
const stream = response.data;
stream.on('data', (chunk) => {
const splitArr = chunk.toString().split('\n\n').filter((line) => line.trim() !== '');
for (const line of splitArr) {
const data = line.replace('data:', '');
let ISEND = false;
try {
ISEND = JSON.parse(data).choices[0].finish_reason === 'stop';
} catch (error) {
ISEND = false;
}
/* 如果结束 返回所有 */
if (ISEND) {
result.text = result.text.trim();
return result;
}
try {
if (data !== " [DONE]" && data !== "[DONE]" && data != "[DONE] ") {
const parsedData = JSON.parse(data);
if (parsedData.id) {
result.id = parsedData.id;
}
if (parsedData.choices?.length) {
const delta = parsedData.choices[0].delta;
result.delta = delta.content;
if (delta?.content) result.text += delta.content;
if (delta.role) {
result.role = delta.role;
}
result.detail = parsedData;
}
onProgress && onProgress({ text: result.text })
}
} catch (error) {
console.log('parse Error', data)
}
}
});
let totalText = '';
messagesHistory.forEach(message => {
totalText += message.content + ' ';
});
stream.on('end', () => {
// 手动计算token
if (result.detail && result.text) {
const promptTokens = getTokenCount(totalText)
const completionTokens = getTokenCount(result.text)
result.detail.usage = {
prompt_tokens: promptTokens,
completion_tokens: completionTokens,
total_tokens: promptTokens + completionTokens,
estimated: true
}
}
return resolve(result);
});
} catch (error) {
reject(error)
}
})
}
}
export function getTokenCount(text: string) {
if (!text) return 0;
// 确保text是字符串类型
@@ -111,28 +184,4 @@ export function getTokenCount(text: string) {
return tokenizer.encode(text).length
}
function compilerToken(model, maxToken){
let max = 0
/* 3.5 */
if(model.includes(3.5)){
max = maxToken > 4096 ? 4096 : maxToken
}
/* 4.0 */
if(model.includes('gpt-4')){
max = maxToken > 8192 ? 8192 : maxToken
}
/* 4.0 preview */
if(model.includes('preview')){
max = maxToken > 4096 ? 4096 : maxToken
}
/* 4.0 32k */
if(model.includes('32k')){
max = maxToken > 32768 ? 32768 : maxToken
}
return max
}

View File

@@ -2,6 +2,7 @@ import Keyv from 'keyv';
import { v4 as uuidv4 } from 'uuid';
import { get_encoding } from '@dqbd/tiktoken';
import { Logger } from '@nestjs/common';
import { includes } from 'lodash';
const tokenizer = get_encoding('cl100k_base');
@@ -96,7 +97,7 @@ export class NineStore implements NineStoreInterface {
let nextNumTokensEstimate = 0;
// messages.push({ role: 'system', content: systemMessage, name })
if (systemMessage) {
const specialModels = ['gemini-pro', 'ERNIE', 'qwen', 'SparkDesk', 'hunyuan'];
const specialModels = ['gemini-pro', 'ERNIE','hunyuan'];
const isSpecialModel = activeModel && specialModels.some((specialModel) => activeModel.includes(specialModel));
if (isSpecialModel) {
messages.push({ role: 'user', content: systemMessage, name });
@@ -146,7 +147,7 @@ export class NineStore implements NineStoreInterface {
let content = text; // 默认情况下使用text作为content
// 特别处理包含 imageUrl 的消息
if (role === 'user' && imageUrl) {
if (imageUrl) {
if (activeModel === 'gpt-4-vision-preview') {
content = [
{ type: 'text', text: text },

View File

@@ -13,7 +13,7 @@ interface UserInfo {
@Injectable()
export class DatabaseService implements OnModuleInit {
constructor(private connection: Connection) {}
constructor(private connection: Connection) { }
async onModuleInit() {
await this.checkSuperAdmin();
await this.checkSiteBaseConfig();
@@ -23,7 +23,7 @@ export class DatabaseService implements OnModuleInit {
async checkSuperAdmin() {
const user = await this.connection.query(`SELECT * FROM users WHERE role = 'super'`);
if (!user || user.length === 0) {
const superPassword = bcrypt.hashSync('123456', 10);
const superPassword = bcrypt.hashSync('123456', 10); //初始密码
const adminPassword = bcrypt.hashSync('123456', 10);
const superEmail = 'default@cooper.com';
const adminEmail = 'defaultAdmin@cooper.com';
@@ -44,7 +44,7 @@ export class DatabaseService implements OnModuleInit {
const userId = user.insertId;
const balance = await this.connection.query(`INSERT INTO balance (userId, balance, usesLeft, paintCount) VALUES ('${userId}', 0, 1000, 100)`);
Logger.log(
`初始化创建${role}用户成功、用户名为[${username}]、初始密码为[${username === 'super' ? '123456' : '123456'}] ==============> 请注意查阅`,
`初始化创建${role}用户成功、用户名为[${username}]、初始密码为[${username === 'super' ? 'nine-super' : '123456'}] ==============> 请注意查阅`,
'DatabaseService',
);
} catch (error) {
@@ -68,17 +68,7 @@ export class DatabaseService implements OnModuleInit {
/* 创建基础的网站数据 */
async createBaseSiteConfig() {
try {
const code = `
<script>
var _hmt = _hmt || [];
(function() {
var hm = document.createElement("script");
hm.src = "https://hm.baidu.com/hm.js?cb8c9a3bcadbc200e950b05f9c61a385";
var s = document.getElementsByTagName("script")[0];
s.parentNode.insertBefore(hm, s);
})();
</script>
`;
const code = ``;
const noticeInfo = `
#### YiAi 欢迎您
@@ -88,21 +78,21 @@ export class DatabaseService implements OnModuleInit {
`;
const defaultConfig = [
{ configKey: 'siteName', configVal: 'Nine Ai', public: 1, encry: 0 },
{ configKey: 'qqNumber', configVal: '840814166', public: 1, encry: 0 },
{ configKey: 'vxNumber', configVal: 'wangpanzhu321', public: 1, encry: 0 },
{ configKey: 'siteName', configVal: 'Yi Ai', public: 1, encry: 0 },
{ configKey: 'qqNumber', configVal: '805239273', public: 1, encry: 0 },
{ configKey: 'vxNumber', configVal: 'HelloWordYi819', public: 1, encry: 0 },
{ configKey: 'robotAvatar', configVal: '', public: 1, encry: 0 },
{
configKey: 'userDefautlAvatar',
configVal: 'https://public-1300678944.cos.ap-shanghai.myqcloud.com/blog/1682571295452image.png',
configVal: '',
public: 0,
encry: 0,
},
{ configKey: 'baiduCode', configVal: code, public: 1, encry: 0 },
{ configKey: 'baiduSiteId', configVal: '19024441', public: 0, encry: 0 },
{ configKey: 'baiduSiteId', configVal: '', public: 0, encry: 0 },
{
configKey: 'baiduToken',
configVal: '121.a1600b9b60910feea2ef627ea9776a6f.YGP_CWCOA2lNcIGJ27BwXGxa6nZhBQyLUS4XVaD.TWt9TA',
configVal: '',
public: 0,
encry: 0,
},
@@ -110,25 +100,25 @@ export class DatabaseService implements OnModuleInit {
{ configKey: 'openaiBaseUrl', configVal: 'https://api.openai.com', public: 0, encry: 0 },
{ configKey: 'noticeInfo', configVal: noticeInfo, public: 1, encry: 0 },
{ configKey: 'registerVerifyEmailTitle', configVal: 'NineTeam团队账号验证', public: 0, encry: 0 },
{ configKey: 'registerVerifyEmailTitle', configVal: 'Yi Ai团队账号验证', public: 0, encry: 0 },
{
configKey: 'registerVerifyEmailDesc',
configVal: '欢迎使用Nine Team团队的产品服务,请在五分钟内完成你的账号激活,点击以下按钮激活您的账号,',
configVal: '欢迎使用Yi Ai团队的产品服务,请在五分钟内完成你的账号激活,点击以下按钮激活您的账号,',
public: 0,
encry: 0,
},
{ configKey: 'registerVerifyEmailFrom', configVal: 'NineTeam团队', public: 0, encry: 0 },
{ configKey: 'registerVerifyEmailFrom', configVal: 'Yi Ai团队', public: 0, encry: 0 },
{ configKey: 'registerVerifyExpir', configVal: '1800', public: 0, encry: 0 },
{ configKey: 'registerSuccessEmailTitle', configVal: 'NineTeam团队账号激活成功', public: 0, encry: 0 },
{ configKey: 'registerSuccessEmailTeamName', configVal: 'NineTeam团队', public: 0, encry: 0 },
{ configKey: 'registerSuccessEmailTitle', configVal: 'Yi Ai账号激活成功', public: 0, encry: 0 },
{ configKey: 'registerSuccessEmailTeamName', configVal: 'Yi Ai', public: 0, encry: 0 },
{
configKey: 'registerSuccessEmaileAppend',
configVal: ',请妥善保管您的账号,我们将为您赠送50次对话额度和5次绘画额度、祝您使用愉快',
configVal: ',请妥善保管您的账号,祝您使用愉快',
public: 0,
encry: 0,
},
{ configKey: 'registerFailEmailTitle', configVal: 'NineTeam账号激活失败', public: 0, encry: 0 },
{ configKey: 'registerFailEmailTeamName', configVal: 'NineTeam团队', public: 0, encry: 0 },
{ configKey: 'registerFailEmailTitle', configVal: 'Yi Ai账号激活失败', public: 0, encry: 0 },
{ configKey: 'registerFailEmailTeamName', configVal: 'Yi Ai团队', public: 0, encry: 0 },
/* 注册默认设置 */
{ configKey: 'registerSendStatus', configVal: '1', public: 1, encry: 0 },
{ configKey: 'registerSendModel3Count', configVal: '30', public: 1, encry: 0 },
@@ -136,16 +126,16 @@ export class DatabaseService implements OnModuleInit {
{ configKey: 'registerSendDrawMjCount', configVal: '3', public: 1, encry: 0 },
{ configKey: 'firstRegisterSendStatus', configVal: '1', public: 1, encry: 0 },
{ configKey: 'firstRegisterSendRank', configVal: '500', public: 1, encry: 0 },
{ configKey: 'firstRregisterSendModel3Count', configVal: '20', public: 1, encry: 0 },
{ configKey: 'firstRregisterSendModel4Count', configVal: '2', public: 1, encry: 0 },
{ configKey: 'firstRregisterSendDrawMjCount', configVal: '3', public: 1, encry: 0 },
{ configKey: 'firstRregisterSendModel3Count', configVal: '10', public: 1, encry: 0 },
{ configKey: 'firstRregisterSendModel4Count', configVal: '10', public: 1, encry: 0 },
{ configKey: 'firstRregisterSendDrawMjCount', configVal: '10', public: 1, encry: 0 },
{ configKey: 'inviteSendStatus', configVal: '1', public: 1, encry: 0 },
{ configKey: 'inviteGiveSendModel3Count', configVal: '30', public: 1, encry: 0 },
{ configKey: 'inviteGiveSendModel4Count', configVal: '3', public: 1, encry: 0 },
{ configKey: 'inviteGiveSendDrawMjCount', configVal: '1', public: 1, encry: 0 },
{ configKey: 'inviteGiveSendModel3Count', configVal: '0', public: 1, encry: 0 },
{ configKey: 'inviteGiveSendModel4Count', configVal: '0', public: 1, encry: 0 },
{ configKey: 'inviteGiveSendDrawMjCount', configVal: '0', public: 1, encry: 0 },
{ configKey: 'invitedGuestSendModel3Count', configVal: '10', public: 1, encry: 0 },
{ configKey: 'invitedGuestSendModel4Count', configVal: '1', public: 1, encry: 0 },
{ configKey: 'invitedGuestSendDrawMjCount', configVal: '1', public: 1, encry: 0 },
{ configKey: 'invitedGuestSendModel4Count', configVal: '10', public: 1, encry: 0 },
{ configKey: 'invitedGuestSendDrawMjCount', configVal: '10', public: 1, encry: 0 },
{ configKey: 'isVerifyEmail', configVal: '1', public: 1, encry: 0 },
];

View File

@@ -22,9 +22,6 @@ export class MidjourneyEntity extends BaseEntity {
@Column({ comment: '垫图图片 + 绘画描述词 + 额外参数 = 完整的prompt', type: 'text' })
fullPrompt: string;
@Column({ comment: '随机产生的绘画ID用于拿取比对结果' })
randomDrawId: string;
@Column({ comment: '当前绘制任务的进度', nullable: true })
progress: number;
@@ -35,7 +32,7 @@ export class MidjourneyEntity extends BaseEntity {
status: number;
@Column({ comment: 'mj绘画的动作、绘图、放大、变换、图生图' })
action: number;
action: string;
@Column({ comment: '一组图片的第几张、放大或者变换的时候需要使用', nullable: true })
orderId: number;
@@ -43,14 +40,17 @@ export class MidjourneyEntity extends BaseEntity {
@Column({ comment: '是否推荐0: 默认不推荐 1: 推荐', nullable: true, default: 0 })
rec: number;
@Column({ comment: '对图片操作的', nullable: true })
customId: string;
@Column({ comment: '绘画的ID每条不一样', nullable: true })
message_id: string;
drawId: string;
@Column({ comment: '图片放大或者变体的ID', nullable: true })
custom_id: string;
@Column({ comment: '图片链接', nullable: true, type: 'text' })
drawUrl: string;
@Column({ comment: '图片信息尺寸', nullable: true, type: 'text' })
fileInfo: string;
@Column({ comment: '图片比例', nullable: true, type: 'text' })
drawRatio: string;
@Column({ comment: '扩展参数', nullable: true, type: 'text' })
extend: string;
@@ -60,4 +60,5 @@ export class MidjourneyEntity extends BaseEntity {
@Column({ comment: '是否存入了图片到配置的储存项 配置了则存储 不配置地址则是源地址', default: true })
isSaveImg: boolean;
messageId: any;
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
import { AddBadWordDto } from '../../badwords/dto/addBadWords.dto';
import { AddBadWordDto } from './../../badwords/dto/addBadWords.dto';
import { IsNotEmpty, MinLength, MaxLength, IsEmail, IsOptional, IsNumber } from 'class-validator';
import { ApiProperty } from '@nestjs/swagger';
@@ -27,6 +27,9 @@ export class SetModelDto {
@ApiProperty({ example: 1, description: 'key的权重' })
keyWeight: number;
@ApiProperty({ example: 1, description: '模型排序' })
modelOrder: number;
@ApiProperty({ example: 4096, description: '模型支持的最大TOken数量', required: true })
maxModelTokens: number;
@@ -53,9 +56,10 @@ export class SetModelDto {
@ApiProperty({ example: true, description: '是否设置为绘画Key', required: false })
isDraw: boolean;
//设置token计费
@ApiProperty({ example: true, description: '是否使用token计费', required: false })
isTokenBased: boolean;
@ApiProperty({ example: true, description: 'token计费比例', required: false })
tokenFeeRatio: number;
}

View File

@@ -42,4 +42,7 @@ export class ModelsTypeEntity extends BaseEntity {
@Column({ comment: '是否为特殊模型、可以提供联想翻译、思维导图等特殊操作', default: 0 })
isUseTool: boolean;
@Column({ comment: '模型排序', default: 1 })
modelOrder: number;
}

View File

@@ -69,4 +69,8 @@ export class ModelsEntity extends BaseEntity {
@Column({ comment: 'token计费比例', default: 0 })
tokenFeeRatio: number;
@Column({ comment: 'key权重', default: 1 })
modelOrder: number;
}

View File

@@ -5,7 +5,7 @@ import { ModelsEntity } from './models.entity';
import { SetModelDto } from './dto/setModel.dto';
import { QueryModelDto } from './dto/queryModel.dto';
import { ModelsMapCn } from '@/common/constants/status.constant';
import { getAccessToken } from '../chatgpt/baidu';
// import { getAccessToken } from '../chatgpt/baidu';
import { getRandomItemFromArray, hideString } from '@/common/utils';
import { ModelsTypeEntity } from './modelType.entity';
import { SetModelTypeDto } from './dto/setModelType.dto';
@@ -19,8 +19,8 @@ export class ModelsService {
private readonly modelsEntity: Repository<ModelsEntity>,
@InjectRepository(ModelsTypeEntity)
private readonly modelsTypeEntity: Repository<ModelsTypeEntity>,
){}
) { }
private modelTypes = []
private modelMaps = {}
private keyList = {}
@@ -28,145 +28,145 @@ export class ModelsService {
private keyPoolMap = {} // 记录每个模型的所有key 并且记录顺序
private keyPoolIndexMap = {} // 记录每个模型的当前调用的下标
async onModuleInit(){
async onModuleInit() {
await this.initCalcKey()
this.refreshBaiduAccesstoken()
}
/* 初始化整理所有key 进行分类并且默认一个初始模型配置 默认是配置的第一个分类的第一个key为准 */
async initCalcKey(){
async initCalcKey() {
this.keyPoolMap = {}
this.keyPoolIndexMap = {}
this.keyList = {}
this.modelMaps = {}
this.modelTypes = []
const allKeys = await this.modelsEntity.find({where: { status: true }})
const keyTypes = allKeys.reduce( (pre: any, cur ) => {
if(!pre[cur.keyType]){
const allKeys = await this.modelsEntity.find({ where: { status: true } })
const keyTypes = allKeys.reduce((pre: any, cur) => {
if (!pre[cur.keyType]) {
pre[cur.keyType] = [cur]
}else{
} else {
pre[cur.keyType].push(cur)
}
return pre
}, {})
this.modelTypes = Object.keys(keyTypes).map( keyType => {
return { label: ModelsMapCn[keyType] , val: keyType}
this.modelTypes = Object.keys(keyTypes).map(keyType => {
return { label: ModelsMapCn[keyType], val: keyType }
})
this.modelMaps = keyTypes
this.keyList = {}
allKeys.forEach( keyDetail => {
allKeys.forEach(keyDetail => {
const { keyType, model, keyWeight } = keyDetail
if(!this.keyPoolMap[model]) this.keyPoolMap[model] = []
if (!this.keyPoolMap[model]) this.keyPoolMap[model] = []
for (let index = 0; index < keyWeight; index++) {
this.keyPoolMap[model].push(keyDetail)
}
if(!this.keyPoolIndexMap[model]) this.keyPoolIndexMap[model] = 0
if(!this.keyList[keyType]) this.keyList[keyType] = {}
if(!this.keyList[keyType][model]) this.keyList[keyType][model] = []
if (!this.keyPoolIndexMap[model]) this.keyPoolIndexMap[model] = 0
if (!this.keyList[keyType]) this.keyList[keyType] = {}
if (!this.keyList[keyType][model]) this.keyList[keyType][model] = []
this.keyList[keyType][model].push(keyDetail)
})
}
/* lock key 自动锁定key */
async lockKey(keyId, remark, keyStatus = -1){
async lockKey(keyId, remark, keyStatus = -1) {
const res = await this.modelsEntity.update({ id: keyId }, { status: false, keyStatus, remark });
Logger.error(`key: ${keyId} 欠费或被官方封禁导致不可用,已被系统自动锁定`);
this.initCalcKey()
}
/* 获取本次调用key的详细信息 */
async getCurrentModelKeyInfo(model){
if(!this.keyPoolMap[model]){
async getCurrentModelKeyInfo(model) {
if (!this.keyPoolMap[model]) {
throw new HttpException('当前调用模型已经被移除、请重新选择模型!', HttpStatus.BAD_REQUEST)
}
/* 调用下标+1 */
this.keyPoolIndexMap[model]++
/* 判断下标超出边界没有 */
const index = this.keyPoolIndexMap[model]
if(index >= this.keyPoolMap[model].length) this.keyPoolIndexMap[model] = 0
const key = this.keyPoolMap[model][this.keyPoolIndexMap[model]]
if (index >= this.keyPoolMap[model].length) this.keyPoolIndexMap[model] = 0
const key = this.keyPoolMap[model][this.keyPoolIndexMap[model]]
return key
}
/* 通过现有配置的key和分类给到默认的配置信息 默认给到第一个分类的第一个key的配置 */
async getBaseConfig(appId?: number): Promise<any>{
if(!this.modelTypes.length || !Object.keys(this.modelMaps).length) return;
async getBaseConfig(appId?: number): Promise<any> {
if (!this.modelTypes.length || !Object.keys(this.modelMaps).length) return;
/* 有appid只可以使用openai 的 模型 */
const modelTypeInfo = appId ? this.modelTypes.find( item => Number(item.val) === 1) : this.modelTypes[0]
const modelTypeInfo = appId ? this.modelTypes.find(item => Number(item.val) === 1) : this.modelTypes[0]
// TODO 第0个会有问题 先添加的4默认就是模型4了 后面优化下
if(!modelTypeInfo) return;
if (!modelTypeInfo) return;
const { keyType, modelName, model, maxModelTokens, maxResponseTokens, deductType, deduct, maxRounds } = this.modelMaps[modelTypeInfo.val][0] // 取到第一个默认的配置项信息
return {
modelTypeInfo,
modelInfo: { keyType, modelName, model, maxModelTokens, maxResponseTokens, topN: 0.8, systemMessage: '', deductType, deduct, maxRounds, rounds: 8 }
modelInfo: { keyType, modelName, model, maxModelTokens, maxResponseTokens, topN: 0.8, systemMessage: '', deductType, deduct, maxRounds, rounds: 8 }
}
}
async setModel(params: SetModelDto){
try {
const { id } = params
params.status && (params.keyStatus = 1)
if(id){
const res = await this.modelsEntity.update({id}, params)
await this.initCalcKey()
return res.affected > 0
}else{
const { keyType, key } = params
if(Number(keyType !== 1)){
const res = await this.modelsEntity.save(params)
async setModel(params: SetModelDto) {
try {
const { id } = params
params.status && (params.keyStatus = 1)
if (id) {
const res = await this.modelsEntity.update({ id }, params)
await this.initCalcKey()
if(keyType === 2){ //百度的需要刷新token
this.refreshBaiduAccesstoken()
return res.affected > 0
} else {
const { keyType, key } = params
if (Number(keyType !== 1)) {
const res = await this.modelsEntity.save(params)
await this.initCalcKey()
return res
} else {
const data = key.map(k => {
try {
const data = JSON.parse(JSON.stringify(params))
data.key = k
return data
} catch (error) {
console.log('parse error: ', error);
}
})
const res = await this.modelsEntity.save(data)
await this.initCalcKey()
return res
}
return res
}else{
const data = key.map( k => {
try {
const data = JSON.parse(JSON.stringify(params))
data.key = k
return data
} catch (error) {
console.log('parse error: ', error);
}
})
const res = await this.modelsEntity.save(data)
await this.initCalcKey()
return res
}
} catch (error) {
console.log('error: ', error);
}
} catch (error) {
console.log('error: ', error);
}
}
async delModel({id}){
if(!id) {
async delModel({ id }) {
if (!id) {
throw new HttpException('缺失必要参数!', HttpStatus.BAD_REQUEST)
}
const m = await this.modelsEntity.findOne({where: {id}})
if(!m){
const m = await this.modelsEntity.findOne({ where: { id } })
if (!m) {
throw new HttpException('当前账号不存在!', HttpStatus.BAD_REQUEST)
}
const res = await this.modelsEntity.delete({id})
const res = await this.modelsEntity.delete({ id })
await this.initCalcKey()
return res;
}
async queryModels(req, params: QueryModelDto){
async queryModels(req, params: QueryModelDto) {
const { role } = req.user
const { keyType, key, status, model, page = 1, size = 10 } = params
let where: any = {}
keyType && (where.keyType = keyType)
model && (where.model = model)
status && (where.status = Number(status) === 1 ? true : false)
key && ( where.key = Like(`%${key}%`))
key && (where.key = Like(`%${key}%`))
const [rows, count] = await this.modelsEntity.findAndCount({
where: where,
skip: (page - 1) * size,
take: size,
where: where,
order: {
modelOrder: 'ASC'
},
skip: (page - 1) * size,
take: size,
})
if(role !== 'super'){
rows.forEach( item => {
if (role !== 'super') {
rows.forEach(item => {
item.key && (item.key = hideString(item.key))
item.secret && (item.secret = hideString(item.secret))
})
@@ -176,24 +176,29 @@ export class ModelsService {
}
/* 客户端查询到的所有的配置的模型类别 以及类别下自定义的多少中文模型名称 */
async modelsList(){
const cloneModelMaps = JSON.parse(JSON.stringify(this.modelMaps))
Object.keys(cloneModelMaps).forEach( key => {
async modelsList() {
const cloneModelMaps = JSON.parse(JSON.stringify(this.modelMaps));
Object.keys(cloneModelMaps).forEach(key => {
// 对每个模型进行排序
cloneModelMaps[key] = cloneModelMaps[key].sort((a, b) => a.modelOrder - b.modelOrder);
cloneModelMaps[key] = Array.from(
cloneModelMaps[key].map( t => {
const { modelName, model, deduct, deductType, maxRounds } = t
return { modelName, model, deduct, deductType, maxRounds }
}).reduce((map, obj) => map.set(obj.modelName, obj), new Map()).values()
cloneModelMaps[key]
.map(t => {
const { modelName, model, deduct, deductType, maxRounds } = t;
return { modelName, model, deduct, deductType, maxRounds };
})
.reduce((map, obj) => map.set(obj.modelName, obj), new Map()).values()
);
})
});
return {
modelTypeList: this.modelTypes,
modelMaps: cloneModelMaps
}
};
}
/* 记录使用次数和使用的token数量 */
async saveUseLog(id, useToken){
async saveUseLog(id, useToken) {
await this.modelsEntity
.createQueryBuilder()
.update(ModelsEntity)
@@ -202,54 +207,32 @@ export class ModelsService {
.execute();
}
async refreshBaiduAccesstoken(){
const allKeys = await this.modelsEntity.find({ where: { keyType: 2 } })
const keysMap: any = {}
allKeys.forEach( keyInfo => {
const { key, secret } = keyInfo
if(!keysMap.key){
keysMap[key] = [{ keyInfo }]
}else{
keysMap[key].push(keyInfo)
}
})
Object.keys(keysMap).forEach( async key => {
const {secret, id } = keysMap[key][0]['keyInfo']
const accessToken: any = await getAccessToken(key, secret)
await this.modelsEntity.update({ key }, { accessToken })
})
setTimeout(() => {
this.initCalcKey()
}, 1000)
}
/* 获取一张绘画key */
async getRandomDrawKey(){
const drawkeys = await this.modelsEntity.find({where: { isDraw: true, status: true }})
if(!drawkeys.length){
async getRandomDrawKey() {
const drawkeys = await this.modelsEntity.find({ where: { isDraw: true, status: true } })
if (!drawkeys.length) {
throw new HttpException('当前未指定特殊模型KEY、前往后台模型池设置吧', HttpStatus.BAD_REQUEST)
}
return getRandomItemFromArray(drawkeys)
}
/* 获取所有key */
async getAllKey(){
async getAllKey() {
return await this.modelsEntity.find()
}
/* 查询模型类型 */
async queryModelType(params: QueryModelTypeDto){
async queryModelType(params: QueryModelTypeDto) {
return 1
}
/* 创建修改模型类型 */
async setModelType(params: SetModelTypeDto){
async setModelType(params: SetModelTypeDto) {
return 1
}
/* 删除模型类型 */
async delModelType(params){
async delModelType(params) {
return 1
}

View File

@@ -20,9 +20,9 @@ export class MjDrawDto {
@IsOptional()
imgUrl?: string;
@ApiProperty({ example: 1, description: '绘画动作 绘图、放大、变换、图生图' })
@ApiProperty({ example: 'IMAGINE', description: '任务类型,可用值:IMAGINE,UPSCALE,VARIATION,ZOOM,PAN,DESCRIBE,BLEND,SHORTEN,SWAP_FACE' })
@IsOptional()
action: number;
action: string;
@ApiProperty({ example: 1, description: '变体或者放大的序号' })
@IsOptional()
@@ -31,4 +31,8 @@ export class MjDrawDto {
@ApiProperty({ example: 1, description: '绘画的DBID' })
@IsOptional()
drawId: number;
@ApiProperty({ example: 1, description: '任务ID' })
@IsOptional()
taskId: number;
}

View File

@@ -15,7 +15,7 @@ export class QueueService implements OnApplicationBootstrap {
private readonly midjourneyService: MidjourneyService,
private readonly userBalanceService: UserBalanceService,
private readonly globalConfigService: GlobalConfigService,
) {}
) { }
private readonly jobIds: any[] = [];
async onApplicationBootstrap() {
@@ -27,14 +27,13 @@ export class QueueService implements OnApplicationBootstrap {
/* 提交绘画任务 */
async addMjDrawQueue(body: MjDrawDto, req: Request) {
const { prompt, imgUrl, extraParam, orderId, action = 1, drawId } = body;
const { imgUrl, orderId, action, drawId } = body;
/* 限制普通用户队列最多可以有两个任务在排队或者等待中 */
await this.midjourneyService.checkLimit(req);
/* 检测余额 */
await this.userBalanceService.validateBalance(req, 'mjDraw', action === 2 ? 1 : 4);
await this.userBalanceService.validateBalance(req, 'mjDraw', action === 'UPSCALE' ? 1 : 4);
/* 绘图或者图生图 */
if (action === MidjourneyActionEnum.DRAW || action === MidjourneyActionEnum.GENERATE) {
if (action === 'IMAGINE') {
/* 绘图或者图生图是相同的 区分一个action即可 */
const randomDrawId = `${createRandomUid()}`;
const params = { ...body, userId: req.user.id, randomDrawId };
@@ -44,84 +43,28 @@ export class QueueService implements OnApplicationBootstrap {
/* 添加任务到队列 通过imgUrl判断是不是图生图 */
const job = await this.mjDrawQueue.add(
'mjDraw',
{ id: res.id, action: imgUrl ? 4 : 1, userId: req.user.id },
{ id: res.id, action: action, userId: req.user.id },
{ delay: 1000, timeout: +timeout },
);
/* 绘图和图生图扣除余额4 */
this.jobIds.push(job.id);
/* 扣费 */
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return true;
} else {
const { orderId, action, drawId } = body;
const actionDetail = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
const params = { ...body, userId: req.user.id, ...actionDetail };
const res = await this.midjourneyService.addDrawQueue(params);
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
return;
}
if (!drawId || !orderId) {
throw new HttpException('缺少必要参数!', HttpStatus.BAD_REQUEST);
}
/* 图片操作 */
/* 图片放大 */
if (action === MidjourneyActionEnum.UPSCALE) {
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
const { custom_id } = actionDetail;
/* 检测当前图片是不是已经放大过了 */
await this.midjourneyService.checkIsUpscale(custom_id);
const params = { ...body, userId: req.user.id, ...actionDetail };
const res = await this.midjourneyService.addDrawQueue(params);
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
/* 扣费 */
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 1, 1);
this.jobIds.push(job.id);
return;
}
/* 图片变体 */
if (action === MidjourneyActionEnum.VARIATION) {
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
const params = { ...body, userId: req.user.id, ...actionDetail };
const res = await this.midjourneyService.addDrawQueue(params);
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
/* 扣费 */
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
/* 重新生成 */
if (action === MidjourneyActionEnum.REGENERATE) {
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
const params = { ...body, userId: req.user.id, ...actionDetail };
const res = await this.midjourneyService.addDrawQueue(params);
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
/* 对图片增强 Vary */
if (action === MidjourneyActionEnum.VARY) {
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
const params = { ...body, userId: req.user.id, ...actionDetail };
const res = await this.midjourneyService.addDrawQueue(params);
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
/* 对图片缩放 Zoom */
if (action === MidjourneyActionEnum.ZOOM) {
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
const params = { ...body, userId: req.user.id, ...actionDetail };
const res = await this.midjourneyService.addDrawQueue(params);
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
}
/* 查询队列 */

View File

@@ -1,4 +1,4 @@
import { GlobalConfigService } from '../globalConfig/globalConfig.service';
import { GlobalConfigService } from './../globalConfig/globalConfig.service';
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
import { Cron, CronExpression } from '@nestjs/schedule';
import { UserBalanceEntity } from '../userBalance/userBalance.entity';
@@ -13,7 +13,7 @@ export class TaskService {
private readonly userBalanceEntity: Repository<UserBalanceEntity>,
private readonly globalConfigService: GlobalConfigService,
private readonly modelsService: ModelsService,
) {}
) { }
/* 每小时刷新一次微信的token */
@Cron(CronExpression.EVERY_HOUR)
@@ -40,8 +40,8 @@ export class TaskService {
}
/* 每小时检测一次授权 */
@Cron('0 0 */5 * *')
refreshBaiduAccesstoken() {
this.modelsService.refreshBaiduAccesstoken();
}
// @Cron('0 0 */5 * *')
// refreshBaiduAccesstoken() {
// this.modelsService.refreshBaiduAccesstoken();
// }
}

View File

@@ -10,33 +10,46 @@ import * as FormData from 'form-data';
@Injectable()
export class UploadService implements OnModuleInit {
constructor(private readonly globalConfigService: GlobalConfigService) {}
constructor(private readonly globalConfigService: GlobalConfigService) { }
private tencentCos: any;
onModuleInit() {}
onModuleInit() { }
async uploadFile(file) {
const { filename: name, originalname, buffer, dir = 'ai', mimetype } = file;
const fileTyle = mimetype ? mimetype.split('/')[1] : '';
const filename = originalname || name
Logger.debug(`准备上传文件: ${filename}, 类型: ${fileTyle}`, 'UploadService');
const {
tencentCosStatus = 0,
aliOssStatus = 0,
cheveretoStatus = 0,
} = await this.globalConfigService.getConfigs(['tencentCosStatus', 'aliOssStatus', 'cheveretoStatus']);
Logger.debug(`上传配置状态 - 腾讯云: ${tencentCosStatus}, 阿里云: ${aliOssStatus}, Chevereto: ${cheveretoStatus}`, 'UploadService');
if (!Number(tencentCosStatus) && !Number(aliOssStatus) && !Number(cheveretoStatus)) {
throw new HttpException('请先前往后台配置上传图片的方式', HttpStatus.BAD_REQUEST);
}
if (Number(tencentCosStatus)) {
return this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle });
}
if (Number(aliOssStatus)) {
return await this.uploadFileByAliOss({ filename, buffer, dir, fileTyle });
}
if (Number(cheveretoStatus)) {
const { filename, buffer: fromBuffer, dir } = file;
return await this.uploadFileByChevereto({ filename, buffer: fromBuffer.toString('base64'), dir, fileTyle });
try {
if (Number(tencentCosStatus)) {
Logger.debug(`使用腾讯云COS上传`, 'UploadService');
return await this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle });
}
if (Number(aliOssStatus)) {
Logger.debug(`使用阿里云OSS上传`, 'UploadService');
return await this.uploadFileByAliOss({ filename, buffer, dir, fileTyle });
}
if (Number(cheveretoStatus)) {
Logger.debug(`使用Chevereto上传`, 'UploadService');
const { filename, buffer: fromBuffer, dir } = file;
return await this.uploadFileByChevereto({ filename, buffer: fromBuffer.toString('base64'), dir, fileTyle });
}
} catch (error) {
Logger.error(`上传失败: ${error.message}`, 'UploadService');
throw error; // 重新抛出异常,以便调用方可以处理
}
}
@@ -122,23 +135,10 @@ export class UploadService implements OnModuleInit {
this.tencentCos = new TENCENTCOS({ SecretId, SecretKey, FileParallelLimit: 10 });
try {
const proxyMj = (await this.globalConfigService.getConfigs(['mjProxy'])) || 0;
/* 开启代理 */
if (Number(proxyMj) === 1) {
const data = { cosType: 'tencent', url, cosParams: { Bucket, Region, SecretId, SecretKey } };
const mjProxyUrl = (await this.globalConfigService.getConfigs(['mjProxyUrl'])) || 'http://172.247.48.137:8000';
const res = await axios.post(`${mjProxyUrl}/mj/replaceUpload`, data);
if (!res.data) throw new HttpException('上传图片失败[ten][url]', HttpStatus.BAD_REQUEST);
let locationUrl = res.data.replace(/^(http:\/\/|https:\/\/|\/\/|)(.*)/, 'https://$2');
const { acceleratedDomain } = await this.getUploadConfig('tencent');
if (acceleratedDomain) {
locationUrl = locationUrl.replace(/^(https:\/\/[^/]+)(\/.*)$/, `https://${acceleratedDomain}$2`);
console.log('当前已开启全球加速----------------->');
}
return locationUrl;
} else {
const buffer = await this.getBufferFromUrl(url);
return await this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle: '' });
}
const buffer = await this.getBufferFromUrl(url);
return await this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle: '' });
} catch (error) {
console.log('TODO->error: ', error);
throw new HttpException('上传图片失败[ten][url]', HttpStatus.BAD_REQUEST);