新增gpt-4V上传 修复部分bug

This commit is contained in:
小易
2024-01-28 18:41:04 +08:00
parent 6dc767f009
commit 8db214371a
26 changed files with 1043 additions and 685 deletions

BIN
service/src/.DS_Store vendored

Binary file not shown.

View File

@@ -32,7 +32,7 @@ async function bootstrap() {
createSwagger(app);
const server = await app.listen(PORT, () => {
Logger.log(`服务启动成功: http://localhost:${PORT}/nineai/swagger/docs 作者:小易 QQ805239273`, 'Main');
Logger.log(`服务启动成功: http://localhost:${PORT}/nineai/swagger/docs`, 'Main');
});
server.timeout = 5 * 60 * 1000;
}

View File

@@ -58,6 +58,9 @@ export class ChatLogEntity extends BaseEntity {
@Column({ comment: '图片信息的string', nullable: true, type: 'text' })
fileInfo: string;
@Column({ comment: '上传图片的信息', nullable: true, type: 'text' })
imageUrl: string;
@Column({ comment: 'role system user assistant', nullable: true })
role: string;

View File

@@ -211,7 +211,7 @@ export class ChatLogService {
}
const list = await this.chatLogEntity.find({ where });
return list.map((item) => {
const { prompt, role, answer, createdAt, model, conversationOptions, requestOptions, id } = item;
const { prompt, role, answer, createdAt, model, conversationOptions, requestOptions, id,imageUrl} = item;
let parseConversationOptions: any = null
let parseRequestOptions: any = null
try {
@@ -228,6 +228,8 @@ export class ChatLogService {
error: false,
conversationOptions: parseConversationOptions,
requestOptions: parseRequestOptions,
imageUrl,
model
};
});
}

View File

@@ -165,7 +165,7 @@ export class ChatgptService implements OnModuleInit {
/* 不同场景会变更其信息 */
let setSystemMessage = systemMessage;
const { parentMessageId } = options;
const { prompt } = body;
const { prompt ,imageUrl,model:activeModel} = body;
const { groupId, usingNetwork } = options;
// const { model = 3 } = options;
/* 获取当前对话组的详细配置信息 */
@@ -260,6 +260,8 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt,
imageUrl,
activeModel,
answer: '',
promptTokens: prompt_tokens,
completionTokens: 0,
@@ -318,6 +320,8 @@ export class ChatgptService implements OnModuleInit {
const { context: messagesHistory } = await this.nineStore.buildMessageFromParentMessageId(usingNetwork ? netWorkPrompt : prompt, {
parentMessageId,
systemMessage,
imageUrl,
activeModel,
maxModelToken: maxToken,
maxResponseTokens: maxTokenRes,
maxRounds: addOneIfOdd(rounds),
@@ -328,6 +332,8 @@ export class ChatgptService implements OnModuleInit {
maxTokenRes,
apiKey: modelKey,
model,
imageUrl,
activeModel,
temperature,
proxyUrl: proxyResUrl,
onProgress: (chat) => {
@@ -386,6 +392,8 @@ export class ChatgptService implements OnModuleInit {
role: 'user',
name: undefined,
usage: null,
imageUrl,
activeModel,
parentMessageId: parentMessageId,
conversationId: response?.conversationId,
};
@@ -449,6 +457,8 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt,
imageUrl,
activeModel,
answer: '',
promptTokens: prompt_tokens,
completionTokens: 0,

View File

@@ -102,7 +102,11 @@ export function sendMessageFromOpenAi(messagesHistory, inputs ){
export function getTokenCount(text: string) {
if(!text) return 0;
if (!text) return 0;
// 确保text是字符串类型
if (typeof text !== 'string') {
text = String(text);
}
text = text.replace(/<\|endoftext\|>/g, '')
return tokenizer.encode(text).length
}

View File

@@ -1,41 +1,44 @@
import Keyv from 'keyv'
import { v4 as uuidv4 } from "uuid";
import { get_encoding } from '@dqbd/tiktoken'
import Keyv from 'keyv';
import { v4 as uuidv4 } from 'uuid';
import { get_encoding } from '@dqbd/tiktoken';
import { Logger } from '@nestjs/common';
const tokenizer = get_encoding('cl100k_base')
export type Role = 'user' | 'assistant' | 'function'
const tokenizer = get_encoding('cl100k_base');
export type Role = 'user' | 'assistant' | 'function';
interface Options {
store: Keyv
namespace: string
expires?: number
store: Keyv;
namespace: string;
expires?: number;
}
export interface MessageInfo {
id: string
text: string
role: Role
name?: string
id: string;
text: string;
role: Role;
name?: string;
imageUrl?: string;
activeModel?: string;
usage: {
prompt_tokens?: number
completion_tokens?: number
total_tokens?: number
}
parentMessageId?: string
conversationId?: string
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
};
parentMessageId?: string;
conversationId?: string;
}
export interface BuildMessageOptions {
systemMessage?: string
parentMessageId: string
maxRounds?: number
maxModelToken?: number
maxResponseTokens?: number
name?: string
systemMessage?: string;
parentMessageId: string;
maxRounds?: number;
maxModelToken?: number;
maxResponseTokens?: number;
name?: string;
imageUrl?: string;
activeModel?: string;
model?: string;
}
// export interface BuildMessageRes {
@@ -43,12 +46,12 @@ export interface BuildMessageOptions {
// numTokens: number
// maxTokens: number
// }
export type BuildMessageRes = any[]
export type BuildMessageRes = any[];
export interface NineStoreInterface {
getData(id: string): Promise<string>;
setData(message: MessageInfo, expires?: number): Promise<void>;
getUuid(): string
getUuid(): string;
buildMessageFromParentMessageId(string, opt?: BuildMessageOptions): Promise<any>;
}
@@ -58,115 +61,165 @@ export class NineStore implements NineStoreInterface {
private expires: number;
constructor(options: Options) {
const { store, namespace, expires } = this.formatOptions(options)
this.store = store
this.namespace = namespace
this.expires = expires
const { store, namespace, expires } = this.formatOptions(options);
this.store = store;
this.namespace = namespace;
this.expires = expires;
}
public formatOptions(options: Options){
const { store, expires = 1000 * 60 * 60 * 24 * 3, namespace = 'chat'} = options
return { store, namespace, expires }
public formatOptions(options: Options) {
const { store, expires = 1000 * 60 * 60 * 24 * 3, namespace = 'chat' } = options;
return { store, namespace, expires };
}
public generateKey(key){
return this.namespace ? `${this.namespace}-${key}` : key
public generateKey(key: any) {
return this.namespace ? `${this.namespace}-${key}` : key;
}
public async getData(id: string ): Promise<any> {
const res = await this.store.get(id)
return res
public async getData(id: string): Promise<any> {
const res = await this.store.get(id);
return res;
}
public async setData(message, expires = this.expires): Promise<void> {
await this.store.set(message.id, message, expires)
public async setData(message: MessageInfo, expires = this.expires): Promise<void> {
await this.store.set(message.id, message, expires);
}
/**
* @desc 通过传入prompt和parentMessageId 递归往上拿到历史记录组装为模型需要的上下文、
* 可以传入maxRounds限制最大轮次的对话 传入maxModelToken, maxResponseTokens 则通过计算上下文占用的token计算出最大容量
*/
public async buildMessageFromParentMessageId(text: string, options: BuildMessageOptions){
let { maxRounds, maxModelToken, maxResponseTokens, systemMessage = '', name } = options
let { parentMessageId } = options
let messages = []
let nextNumTokensEstimate = 0
public async buildMessageFromParentMessageId(text: string, options: BuildMessageOptions) {
let { maxRounds, maxModelToken, maxResponseTokens, systemMessage = '', name, imageUrl, model, activeModel } = options;
let { parentMessageId } = options;
let messages = [];
let nextNumTokensEstimate = 0;
// messages.push({ role: 'system', content: systemMessage, name })
if (systemMessage) {
messages.push({ role: 'system', content: systemMessage })
const specialModels = ['gemini-pro', 'ERNIE', 'qwen', 'SparkDesk', 'hunyuan'];
const isSpecialModel = activeModel && specialModels.some((specialModel) => activeModel.includes(specialModel));
if (isSpecialModel) {
messages.push({ role: 'user', content: systemMessage, name });
messages.push({ role: 'assistant', content: '好的', name });
} else {
messages.push({ role: 'system', content: systemMessage, name });
}
}
const systemMessageOffset = messages.length
let round = 0
let nextMessages = text ? messages.concat([{ role: 'user', content: text, name }]) : messages
do {
// let parentId = '1bf30262-8f25-4a03-88ad-9d42d55e6f0b'
/* 没有parentMessageId就没有历史 直接返回 */
if(!parentMessageId){
break;
const systemMessageOffset = messages.length;
let round = 0;
// 特殊处理 gpt-4-vision-preview 模型
if (activeModel === 'gpt-4-vision-preview' && imageUrl) {
const content = [
{
type: 'text',
text: text,
},
{
type: 'image_url',
image_url: {
url: imageUrl,
},
},
];
messages.push({ role: 'user', content: content, name });
} else {
// 处理 gpt-4-all 模型
if (model === 'gpt-4-all' && imageUrl) {
text = imageUrl + '\n' + text;
}
const parentMessage = await this.getData(parentMessageId)
messages.push({ role: 'user', content: text, name });
}
// Logger.debug(`发送的参数:${messages}`)
if(!parentMessage){
let nextMessages = messages;
do {
// let parentId = '1bf30262-8f25-4a03-88ad-9d42d55e6f0b'
/* 没有parentMessageId就没有历史 直接返回 */
if (!parentMessageId) {
break;
}
const { text, name, role } = parentMessage
const parentMessage = await this.getData(parentMessageId);
if (!parentMessage) {
break;
}
const { text, name, role, imageUrl } = parentMessage;
let content = text; // 默认情况下使用text作为content
// 特别处理包含 imageUrl 的消息
if (role === 'user' && imageUrl) {
if (activeModel === 'gpt-4-vision-preview') {
content = [
{ type: 'text', text: text },
{ type: 'image_url', image_url: { url: imageUrl } },
];
}
}
/* 将本轮消息插入到列表中 */
nextMessages = nextMessages.slice(0, systemMessageOffset).concat([
{ role, content: text, name },
...nextMessages.slice(systemMessageOffset)
])
round++
{ role, content, name }, // 使用调整后的content
...nextMessages.slice(systemMessageOffset),
]);
// Logger.debug(`nextMessages${JSON.stringify(nextMessages, null, 2)}`);
round++;
/* 如果超出了限制的最大轮次 就退出 不包含本次发送的本身 */
if(maxRounds && round >= maxRounds){
if (maxRounds && round >= maxRounds) {
break;
}
/* 如果传入maxModelToken maxResponseTokens 则判断是否超过边界 */
if(maxModelToken && maxResponseTokens){
const maxNumTokens = maxModelToken - maxResponseTokens // 模型最大token限制减去限制回复剩余空间
if (maxModelToken && maxResponseTokens) {
const maxNumTokens = maxModelToken - maxResponseTokens; // 模型最大token限制减去限制回复剩余空间
/* 当前的对话历史列表合并的总token容量 */
nextNumTokensEstimate = await this._getTokenCount(nextMessages)
nextNumTokensEstimate = await this._getTokenCount(nextMessages);
/* 200是添加的一个安全区间 防止少量超过 待优化 */
const isValidPrompt = nextNumTokensEstimate + 200 <= maxNumTokens
const isValidPrompt = nextNumTokensEstimate + 200 <= maxNumTokens;
/* 如果大于这个区间了说明本轮加入之后导致超过限制、则递归删除头部的对话轮次来保证不出边界 */
if(!isValidPrompt){
nextMessages = this._recursivePruning(nextMessages, maxNumTokens, systemMessage)
if (!isValidPrompt) {
nextMessages = this._recursivePruning(nextMessages, maxNumTokens, systemMessage);
}
}
parentMessageId = parentMessage.parentMessageId
parentMessageId = parentMessage.parentMessageId;
} while (true);
const maxTokens = Math.max(
1,
Math.min(maxModelToken - nextNumTokensEstimate, maxResponseTokens)
)
const maxTokens = Math.max(1, Math.min(maxModelToken - nextNumTokensEstimate, maxResponseTokens));
// Logger.debug(`本轮调用:模型:${model}`)
console.log('本次携带上下文的长度',nextMessages.length, nextNumTokensEstimate )
return { context: nextMessages, round: nextMessages.length, historyToken:nextNumTokensEstimate }
console.log('本次携带上下文的长度', nextMessages.length, nextNumTokensEstimate);
return { context: nextMessages, round: nextMessages.length, historyToken: nextNumTokensEstimate };
}
protected _getTokenCount(messages: any[]) {
let text = messages.reduce( (pre: string, cur: any) => {
return pre+=cur.content
}, '')
text = text.replace(/<\|endoftext\|>/g, '')
return tokenizer.encode(text).length
}
let text = messages.reduce((pre: string, cur: any) => {
// 检查cur.content是否为数组
if (Array.isArray(cur.content)) {
// 提取并连接数组中的文本元素
const contentText = cur.content
.filter((item: { type: string }) => item.type === 'text')
.map((item: { text: any }) => item.text)
.join(' ');
return pre + contentText;
} else {
// 如果不是数组,则直接添加
return pre + (cur.content || '');
}
}, '');
text = text.replace(/<\|endoftext\|>/g, '');
return tokenizer.encode(text).length;
}
/* 递归删除 当token超过模型限制容量 删除到在限制区域内 */
protected _recursivePruning(
messages: MessageInfo[],
maxNumTokens: number,
systemMessage: string
) {
const currentTokens = this._getTokenCount(messages)
protected _recursivePruning(messages: MessageInfo[], maxNumTokens: number, systemMessage: string) {
const currentTokens = this._getTokenCount(messages);
if (currentTokens <= maxNumTokens) {
return messages
return messages;
}
/* 有系统预设则跳过第一条删除 没有则直接删除 */
messages.splice(systemMessage ? 1 : 0, 1)
return this._recursivePruning(messages, maxNumTokens, systemMessage)
messages.splice(systemMessage ? 1 : 0, 1);
return this._recursivePruning(messages, maxNumTokens, systemMessage);
}
public getUuid(){
return uuidv4()
public getUuid() {
return uuidv4();
}
}

View File

@@ -32,7 +32,6 @@ export class MidjourneyService {
private redisCacheService: RedisCacheService,
) {}
private lockPrompt = [];
/* 睡眠 xs */
@@ -111,6 +110,9 @@ export class MidjourneyService {
await this.updateDrawData(jobData, drawRes);
/* 存完解锁当前文件 */
this.lockPrompt = this.lockPrompt.filter((item) => item !== drawInfo.randomDrawId);
/* 只有在画成功后才扣分*/
this.drawSuccess(jobData);
}
return true;
@@ -160,16 +162,16 @@ export class MidjourneyService {
const { id, content, channel_id, attachments = [], timestamp, durationSpent } = drawRes;
const { filename, url, proxy_url, width, height, size } = attachments[0];
/* 将图片存入cos */
const mjNotSaveImg = await this.globalConfigService.getConfigs(['mjNotSaveImg'])
let cosUrl = ''
if(!Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0){
const mjNotSaveImg = await this.globalConfigService.getConfigs(['mjNotSaveImg']);
let cosUrl = '';
if (!Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0) {
Logger.debug(`------> 开始上传图片!!!`, 'MidjourneyService');
const startDate = new Date();
cosUrl = await this.uploadService.uploadFileFromUrl({ filename, url });
const endDate = new Date();
Logger.debug(`本次图片上传耗时为${(endDate.getTime() - startDate.getTime()) / 1000}`, 'MidjourneyService');
}else{
console.log('本次不存图片了')
} else {
console.log('本次不存图片了');
}
/* 记录当前图片存储方式 方便后续对不同平台图片压缩 */
const cosType = await this.uploadService.getUploadType();
@@ -181,11 +183,11 @@ export class MidjourneyService {
fileInfo: JSON.stringify({ width, height, size, filename, cosUrl, cosType }),
extend: this.removeEmoji(JSON.stringify(drawRes)),
durationSpent,
isSaveImg: !Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0,
isSaveImg: !Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0,
};
await this.midjourneyEntity.update({ id: jobData.id }, drawInfo);
} catch (error) {
console.log('TODO->存储图片失败, ', jobData,error);
console.log('TODO->存储图片失败, ', jobData, error);
}
}
@@ -734,17 +736,15 @@ export class MidjourneyService {
take: size,
skip: (page - 1) * size,
});
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl'])
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl']);
rows.forEach((item: any) => {
try {
const { extend, isSaveImg, fileInfo } = item;
const originUrl = JSON.parse(extend)?.attachments[0]?.url
const originUrl = JSON.parse(extend)?.attachments[0]?.url;
item.fileInfo = this.formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl);
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === "U1";
item.originUrl = originUrl
} catch (error) {
}
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === 'U1';
item.originUrl = originUrl;
} catch (error) {}
});
const countQueue = await this.midjourneyEntity.count({ where: { isDelete: 0, status: In([1, 2]) } });
const data: any = { rows: formatCreateOrUpdateDate(rows), count, countQueue };
@@ -755,18 +755,18 @@ export class MidjourneyService {
}
/* 格式化fileinfo 对于不同平台的图片进行压缩 */
formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl) {
formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl) {
if (!fileInfo) return {};
let parseFileInfo: any = null
let parseFileInfo: any = null;
try {
parseFileInfo = JSON.parse(fileInfo);
} catch (error) {
parseFileInfo = null
parseFileInfo = null;
}
if(!parseFileInfo) return;
if (!parseFileInfo) return;
const { url, filename, size, cosUrl, width, height } = parseFileInfo;
const targetSize = 310; // 目标宽度或高度
// TODO判断逻辑有误 腾讯云会导致也判断为 chevereto 更换判断规则
const imgType = cosUrl.includes('cos') ? 'tencent' : cosUrl.includes('oss') ? 'ali' : 'chevereto';
let compress;
@@ -786,10 +786,10 @@ export class MidjourneyService {
}
parseFileInfo.thumbImg = thumbImg;
/* 如果配置了不存储图片 则 isSaceImg 为false的则需要使用反代地址拼接 */
if(!isSaveImg){
const proxyImgUrl = `${mjProxyImgUrl}/mj/pipe?url=${originUrl}`
parseFileInfo.thumbImg = proxyImgUrl
parseFileInfo.cosUrl = proxyImgUrl
if (!isSaveImg) {
const proxyImgUrl = `${mjProxyImgUrl}/mj/pipe?url=${originUrl}`;
parseFileInfo.thumbImg = proxyImgUrl;
parseFileInfo.cosUrl = proxyImgUrl;
}
return parseFileInfo;
}
@@ -859,8 +859,8 @@ export class MidjourneyService {
// return;
// }
const count = await this.midjourneyEntity.count({ where: { userId: id, isDelete: 0, status: In([1, 2]) } });
const mjLimitCount = await this.globalConfigService.getConfigs(['mjLimitCount'])
const max = mjLimitCount ? Number(mjLimitCount) : 2
const mjLimitCount = await this.globalConfigService.getConfigs(['mjLimitCount']);
const max = mjLimitCount ? Number(mjLimitCount) : 2;
if (count >= max) {
throw new HttpException(`当前管理员限制单用户同时最多能执行${max}个任务`, HttpStatus.BAD_REQUEST);
}
@@ -870,11 +870,21 @@ export class MidjourneyService {
async drawFailed(jobData) {
const { id, userId, action } = jobData;
/* 退还余额 放大图片类型2是1 其他都是4 */
const amount = action === 2 ? 1 : 4;
await this.userBalanceService.refundMjBalance(userId, amount);
// const amount = action === 2 ? 1 : 4;
// await this.userBalanceService.refundMjBalance(userId, amount);
await this.midjourneyEntity.update({ id }, { status: 4 });
}
/* 绘图成功扣费 */
async drawSuccess(jobData) {
const { id, userId, action } = jobData;
/* 扣除余额 放大图片类型2是1 其他都是4 */
const amount = action === 2 ? 1 : 4;
Logger.debug(`绘画完成,执行扣费,扣除费用:${amount}积分。`);
await this.userBalanceService.refundMjBalance(userId, -amount);
await this.midjourneyEntity.update({ id }, { status: 3 });
}
/* 获取绘画列表 */
async getList(params: GetListDto) {
const { page = 1, size = 20, rec, userId, status } = params;
@@ -902,17 +912,15 @@ export class MidjourneyService {
skip: (page - 1) * size,
select: ['fileInfo', 'extend', 'prompt', 'createdAt', 'id', 'extend', 'fullPrompt', 'rec', 'isSaveImg'],
});
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl'])
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl']);
rows.forEach((item: any) => {
try {
const { extend, isSaveImg, fileInfo } = item;
const originUrl = JSON.parse(extend)?.attachments[0]?.url
const originUrl = JSON.parse(extend)?.attachments[0]?.url;
item.fileInfo = this.formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl);
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === "U1";
item.originUrl = originUrl
} catch (error) {
}
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === 'U1';
item.originUrl = originUrl;
} catch (error) {}
});
if (Number(size) === 999) {
@@ -931,10 +939,10 @@ export class MidjourneyService {
}
/* */
async getFullPrompt(id: number){
const m = await this.midjourneyEntity.findOne({where: {id}})
if(!m) return ''
const { fullPrompt } = m
async getFullPrompt(id: number) {
const m = await this.midjourneyEntity.findOne({ where: { id } });
if (!m) return '';
const { fullPrompt } = m;
return fullPrompt;
}
@@ -953,27 +961,25 @@ export class MidjourneyService {
skip: (page - 1) * size,
});
const userIds = rows.map((item: any) => item.userId).filter( id => id < 100000);
const userIds = rows.map((item: any) => item.userId).filter((id) => id < 100000);
const userInfos = await this.userEntity.find({ where: { id: In(userIds) }, select: ['id', 'username', 'avatar', 'email'] });
rows.forEach((item: any) => {
item.userInfo = userInfos.find((user) => user.id === item.userId);
});
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl'])
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl']);
rows.forEach((item: any) => {
try {
const { extend, isSaveImg, fileInfo } = item;
const originUrl = JSON.parse(extend)?.attachments[0]?.url
const originUrl = JSON.parse(extend)?.attachments[0]?.url;
item.fileInfo = this.formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl);
// item.isGroup = JSON.parse(extend)?.components[0]?.components.length === 5;
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === "U1";
item.originUrl = originUrl
} catch (error) {
}
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === 'U1';
item.originUrl = originUrl;
} catch (error) {}
});
if (req.user.role !== 'super') {
rows.forEach((item: any) => {
if(item.userInfo && item.userInfo.email){
if (item.userInfo && item.userInfo.email) {
item.userInfo.email = item.userInfo.email.replace(/(.{2}).+(.{2}@.+)/, '$1****$2');
}
});
@@ -1021,38 +1027,38 @@ export class MidjourneyService {
}
}
async setPrompt(req: Request, body){
async setPrompt(req: Request, body) {
try {
const { prompt, status, isCarryParams, title, order, id, aspect } = body
if(id){
return await this.mjPromptsEntity.update({id}, {prompt, status, isCarryParams, order, aspect})
}else{
return await this.mjPromptsEntity.save({prompt, status, isCarryParams, title, order, aspect})
}
const { prompt, status, isCarryParams, title, order, id, aspect } = body;
if (id) {
return await this.mjPromptsEntity.update({ id }, { prompt, status, isCarryParams, order, aspect });
} else {
return await this.mjPromptsEntity.save({ prompt, status, isCarryParams, title, order, aspect });
}
} catch (error) {
console.log('error: ', error);
}
}
async delPrompt(req: Request, body){
const {id} = body
if(!id) {
async delPrompt(req: Request, body) {
const { id } = body;
if (!id) {
throw new HttpException('非法操作!', HttpStatus.BAD_REQUEST);
}
return await this.mjPromptsEntity.delete({id})
return await this.mjPromptsEntity.delete({ id });
}
async queryPrompt(){
async queryPrompt() {
return await this.mjPromptsEntity.find({
order: { order: 'DESC' },
})
});
}
async proxyImg(params){
const { url } = params
if(!url) return
async proxyImg(params) {
const { url } = params;
if (!url) return;
const response = await axios.get(url, { responseType: 'arraybuffer' });
const base64 = Buffer.from(response.data).toString('base64');
return base64
return base64;
}
}

View File

@@ -56,4 +56,6 @@ export class SetModelDto {
//设置token计费
@ApiProperty({ example: true, description: '是否使用token计费', required: false })
isTokenBased: boolean;
@ApiProperty({ example: true, description: 'token计费比例', required: false })
tokenFeeRatio: number;
}

View File

@@ -66,4 +66,7 @@ export class ModelsEntity extends BaseEntity {
@Column({ comment: '是否使用token计费: 0:不是 1是', default: 0 })
isTokenBased: boolean;
@Column({ comment: 'token计费比例', default: 0 })
tokenFeeRatio: number;
}

View File

@@ -50,7 +50,7 @@ export class QueueService implements OnApplicationBootstrap {
/* 绘图和图生图扣除余额4 */
this.jobIds.push(job.id);
/* 扣费 */
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return true;
}
@@ -69,7 +69,7 @@ export class QueueService implements OnApplicationBootstrap {
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
/* 扣费 */
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 1, 1);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 1, 1);
this.jobIds.push(job.id);
return;
}
@@ -83,7 +83,7 @@ export class QueueService implements OnApplicationBootstrap {
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
/* 扣费 */
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
@@ -95,7 +95,7 @@ export class QueueService implements OnApplicationBootstrap {
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
@@ -107,7 +107,7 @@ export class QueueService implements OnApplicationBootstrap {
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
@@ -119,7 +119,7 @@ export class QueueService implements OnApplicationBootstrap {
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
this.jobIds.push(job.id);
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
return;
}
}

View File

@@ -300,8 +300,8 @@ export class UserBalanceService {
/* 记录修改使用的token */
const updateBalance = { [updateKey]: b[updateKey] - amount < 0 ? 0 : b[updateKey] - amount, [useKey]: b[useKey] + UseAmount };
/* 记录修改使用的次数 mj不需要 */
useKey === 'useModel3Token' && (updateBalance['useModel3Count'] = b['useModel3Count'] + 1);
useKey === 'useModel4Token' && (updateBalance['useModel4Count'] = b['useModel4Count'] + 1);
useKey === 'useModel3Token' && (updateBalance['useModel3Count'] = b['useModel3Count'] + amount);
useKey === 'useModel4Token' && (updateBalance['useModel4Count'] = b['useModel4Count'] + amount);
const result = await this.userBalanceEntity.update({ userId }, updateBalance);
if (result.affected === 0) {
throw new HttpException('消费余额失败!', HttpStatus.BAD_REQUEST);
@@ -552,7 +552,9 @@ export class UserBalanceService {
}
/* MJ绘画失败退款 */
async refundMjBalance(userId, amount) {}
async refundMjBalance(userId, amount) {
return await this.deductFromBalance(userId, 'mjDraw', -amount);
}
/* V1.5升级将旧版本余额并入到新表 */
async upgradeBalance() {