mirror of
https://github.com/xiaoyiweb/YiAi.git
synced 2025-11-14 05:03:46 +08:00
新增gpt-4V上传 修复部分bug
This commit is contained in:
BIN
service/src/.DS_Store
vendored
BIN
service/src/.DS_Store
vendored
Binary file not shown.
@@ -32,7 +32,7 @@ async function bootstrap() {
|
||||
|
||||
createSwagger(app);
|
||||
const server = await app.listen(PORT, () => {
|
||||
Logger.log(`服务启动成功: http://localhost:${PORT}/nineai/swagger/docs 作者:小易 QQ:805239273`, 'Main');
|
||||
Logger.log(`服务启动成功: http://localhost:${PORT}/nineai/swagger/docs`, 'Main');
|
||||
});
|
||||
server.timeout = 5 * 60 * 1000;
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ export class ChatLogEntity extends BaseEntity {
|
||||
@Column({ comment: '图片信息的string', nullable: true, type: 'text' })
|
||||
fileInfo: string;
|
||||
|
||||
@Column({ comment: '上传图片的信息', nullable: true, type: 'text' })
|
||||
imageUrl: string;
|
||||
|
||||
@Column({ comment: 'role system user assistant', nullable: true })
|
||||
role: string;
|
||||
|
||||
|
||||
@@ -211,7 +211,7 @@ export class ChatLogService {
|
||||
}
|
||||
const list = await this.chatLogEntity.find({ where });
|
||||
return list.map((item) => {
|
||||
const { prompt, role, answer, createdAt, model, conversationOptions, requestOptions, id } = item;
|
||||
const { prompt, role, answer, createdAt, model, conversationOptions, requestOptions, id,imageUrl} = item;
|
||||
let parseConversationOptions: any = null
|
||||
let parseRequestOptions: any = null
|
||||
try {
|
||||
@@ -228,6 +228,8 @@ export class ChatLogService {
|
||||
error: false,
|
||||
conversationOptions: parseConversationOptions,
|
||||
requestOptions: parseRequestOptions,
|
||||
imageUrl,
|
||||
model
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ export class ChatgptService implements OnModuleInit {
|
||||
/* 不同场景会变更其信息 */
|
||||
let setSystemMessage = systemMessage;
|
||||
const { parentMessageId } = options;
|
||||
const { prompt } = body;
|
||||
const { prompt ,imageUrl,model:activeModel} = body;
|
||||
const { groupId, usingNetwork } = options;
|
||||
// const { model = 3 } = options;
|
||||
/* 获取当前对话组的详细配置信息 */
|
||||
@@ -260,6 +260,8 @@ export class ChatgptService implements OnModuleInit {
|
||||
userId: req.user.id,
|
||||
type: DeductionKey.CHAT_TYPE,
|
||||
prompt,
|
||||
imageUrl,
|
||||
activeModel,
|
||||
answer: '',
|
||||
promptTokens: prompt_tokens,
|
||||
completionTokens: 0,
|
||||
@@ -318,6 +320,8 @@ export class ChatgptService implements OnModuleInit {
|
||||
const { context: messagesHistory } = await this.nineStore.buildMessageFromParentMessageId(usingNetwork ? netWorkPrompt : prompt, {
|
||||
parentMessageId,
|
||||
systemMessage,
|
||||
imageUrl,
|
||||
activeModel,
|
||||
maxModelToken: maxToken,
|
||||
maxResponseTokens: maxTokenRes,
|
||||
maxRounds: addOneIfOdd(rounds),
|
||||
@@ -328,6 +332,8 @@ export class ChatgptService implements OnModuleInit {
|
||||
maxTokenRes,
|
||||
apiKey: modelKey,
|
||||
model,
|
||||
imageUrl,
|
||||
activeModel,
|
||||
temperature,
|
||||
proxyUrl: proxyResUrl,
|
||||
onProgress: (chat) => {
|
||||
@@ -386,6 +392,8 @@ export class ChatgptService implements OnModuleInit {
|
||||
role: 'user',
|
||||
name: undefined,
|
||||
usage: null,
|
||||
imageUrl,
|
||||
activeModel,
|
||||
parentMessageId: parentMessageId,
|
||||
conversationId: response?.conversationId,
|
||||
};
|
||||
@@ -449,6 +457,8 @@ export class ChatgptService implements OnModuleInit {
|
||||
userId: req.user.id,
|
||||
type: DeductionKey.CHAT_TYPE,
|
||||
prompt,
|
||||
imageUrl,
|
||||
activeModel,
|
||||
answer: '',
|
||||
promptTokens: prompt_tokens,
|
||||
completionTokens: 0,
|
||||
|
||||
@@ -102,7 +102,11 @@ export function sendMessageFromOpenAi(messagesHistory, inputs ){
|
||||
|
||||
|
||||
export function getTokenCount(text: string) {
|
||||
if(!text) return 0;
|
||||
if (!text) return 0;
|
||||
// 确保text是字符串类型
|
||||
if (typeof text !== 'string') {
|
||||
text = String(text);
|
||||
}
|
||||
text = text.replace(/<\|endoftext\|>/g, '')
|
||||
return tokenizer.encode(text).length
|
||||
}
|
||||
|
||||
@@ -1,41 +1,44 @@
|
||||
import Keyv from 'keyv'
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { get_encoding } from '@dqbd/tiktoken'
|
||||
import Keyv from 'keyv';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { get_encoding } from '@dqbd/tiktoken';
|
||||
import { Logger } from '@nestjs/common';
|
||||
|
||||
const tokenizer = get_encoding('cl100k_base')
|
||||
|
||||
|
||||
export type Role = 'user' | 'assistant' | 'function'
|
||||
const tokenizer = get_encoding('cl100k_base');
|
||||
|
||||
export type Role = 'user' | 'assistant' | 'function';
|
||||
|
||||
interface Options {
|
||||
store: Keyv
|
||||
namespace: string
|
||||
expires?: number
|
||||
store: Keyv;
|
||||
namespace: string;
|
||||
expires?: number;
|
||||
}
|
||||
|
||||
export interface MessageInfo {
|
||||
id: string
|
||||
text: string
|
||||
role: Role
|
||||
name?: string
|
||||
id: string;
|
||||
text: string;
|
||||
role: Role;
|
||||
name?: string;
|
||||
imageUrl?: string;
|
||||
activeModel?: string;
|
||||
usage: {
|
||||
prompt_tokens?: number
|
||||
completion_tokens?: number
|
||||
total_tokens?: number
|
||||
}
|
||||
parentMessageId?: string
|
||||
conversationId?: string
|
||||
prompt_tokens?: number;
|
||||
completion_tokens?: number;
|
||||
total_tokens?: number;
|
||||
};
|
||||
parentMessageId?: string;
|
||||
conversationId?: string;
|
||||
}
|
||||
|
||||
export interface BuildMessageOptions {
|
||||
systemMessage?: string
|
||||
parentMessageId: string
|
||||
maxRounds?: number
|
||||
maxModelToken?: number
|
||||
maxResponseTokens?: number
|
||||
name?: string
|
||||
systemMessage?: string;
|
||||
parentMessageId: string;
|
||||
maxRounds?: number;
|
||||
maxModelToken?: number;
|
||||
maxResponseTokens?: number;
|
||||
name?: string;
|
||||
imageUrl?: string;
|
||||
activeModel?: string;
|
||||
model?: string;
|
||||
}
|
||||
|
||||
// export interface BuildMessageRes {
|
||||
@@ -43,12 +46,12 @@ export interface BuildMessageOptions {
|
||||
// numTokens: number
|
||||
// maxTokens: number
|
||||
// }
|
||||
export type BuildMessageRes = any[]
|
||||
export type BuildMessageRes = any[];
|
||||
|
||||
export interface NineStoreInterface {
|
||||
getData(id: string): Promise<string>;
|
||||
setData(message: MessageInfo, expires?: number): Promise<void>;
|
||||
getUuid(): string
|
||||
getUuid(): string;
|
||||
buildMessageFromParentMessageId(string, opt?: BuildMessageOptions): Promise<any>;
|
||||
}
|
||||
|
||||
@@ -58,115 +61,165 @@ export class NineStore implements NineStoreInterface {
|
||||
private expires: number;
|
||||
|
||||
constructor(options: Options) {
|
||||
const { store, namespace, expires } = this.formatOptions(options)
|
||||
this.store = store
|
||||
this.namespace = namespace
|
||||
this.expires = expires
|
||||
const { store, namespace, expires } = this.formatOptions(options);
|
||||
this.store = store;
|
||||
this.namespace = namespace;
|
||||
this.expires = expires;
|
||||
}
|
||||
|
||||
public formatOptions(options: Options){
|
||||
const { store, expires = 1000 * 60 * 60 * 24 * 3, namespace = 'chat'} = options
|
||||
return { store, namespace, expires }
|
||||
public formatOptions(options: Options) {
|
||||
const { store, expires = 1000 * 60 * 60 * 24 * 3, namespace = 'chat' } = options;
|
||||
return { store, namespace, expires };
|
||||
}
|
||||
|
||||
public generateKey(key){
|
||||
return this.namespace ? `${this.namespace}-${key}` : key
|
||||
public generateKey(key: any) {
|
||||
return this.namespace ? `${this.namespace}-${key}` : key;
|
||||
}
|
||||
|
||||
public async getData(id: string ): Promise<any> {
|
||||
const res = await this.store.get(id)
|
||||
return res
|
||||
public async getData(id: string): Promise<any> {
|
||||
const res = await this.store.get(id);
|
||||
return res;
|
||||
}
|
||||
|
||||
public async setData(message, expires = this.expires): Promise<void> {
|
||||
await this.store.set(message.id, message, expires)
|
||||
public async setData(message: MessageInfo, expires = this.expires): Promise<void> {
|
||||
await this.store.set(message.id, message, expires);
|
||||
}
|
||||
|
||||
/**
|
||||
* @desc 通过传入prompt和parentMessageId 递归往上拿到历史记录组装为模型需要的上下文、
|
||||
* 可以传入maxRounds限制最大轮次的对话 传入maxModelToken, maxResponseTokens 则通过计算上下文占用的token计算出最大容量
|
||||
*/
|
||||
public async buildMessageFromParentMessageId(text: string, options: BuildMessageOptions){
|
||||
let { maxRounds, maxModelToken, maxResponseTokens, systemMessage = '', name } = options
|
||||
let { parentMessageId } = options
|
||||
let messages = []
|
||||
let nextNumTokensEstimate = 0
|
||||
public async buildMessageFromParentMessageId(text: string, options: BuildMessageOptions) {
|
||||
let { maxRounds, maxModelToken, maxResponseTokens, systemMessage = '', name, imageUrl, model, activeModel } = options;
|
||||
let { parentMessageId } = options;
|
||||
let messages = [];
|
||||
let nextNumTokensEstimate = 0;
|
||||
// messages.push({ role: 'system', content: systemMessage, name })
|
||||
if (systemMessage) {
|
||||
messages.push({ role: 'system', content: systemMessage })
|
||||
const specialModels = ['gemini-pro', 'ERNIE', 'qwen', 'SparkDesk', 'hunyuan'];
|
||||
const isSpecialModel = activeModel && specialModels.some((specialModel) => activeModel.includes(specialModel));
|
||||
if (isSpecialModel) {
|
||||
messages.push({ role: 'user', content: systemMessage, name });
|
||||
messages.push({ role: 'assistant', content: '好的', name });
|
||||
} else {
|
||||
messages.push({ role: 'system', content: systemMessage, name });
|
||||
}
|
||||
}
|
||||
const systemMessageOffset = messages.length
|
||||
let round = 0
|
||||
let nextMessages = text ? messages.concat([{ role: 'user', content: text, name }]) : messages
|
||||
do {
|
||||
// let parentId = '1bf30262-8f25-4a03-88ad-9d42d55e6f0b'
|
||||
/* 没有parentMessageId就没有历史 直接返回 */
|
||||
if(!parentMessageId){
|
||||
break;
|
||||
const systemMessageOffset = messages.length;
|
||||
let round = 0;
|
||||
// 特殊处理 gpt-4-vision-preview 模型
|
||||
if (activeModel === 'gpt-4-vision-preview' && imageUrl) {
|
||||
const content = [
|
||||
{
|
||||
type: 'text',
|
||||
text: text,
|
||||
},
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: imageUrl,
|
||||
},
|
||||
},
|
||||
];
|
||||
messages.push({ role: 'user', content: content, name });
|
||||
} else {
|
||||
// 处理 gpt-4-all 模型
|
||||
if (model === 'gpt-4-all' && imageUrl) {
|
||||
text = imageUrl + '\n' + text;
|
||||
}
|
||||
const parentMessage = await this.getData(parentMessageId)
|
||||
messages.push({ role: 'user', content: text, name });
|
||||
}
|
||||
// Logger.debug(`发送的参数:${messages}`)
|
||||
|
||||
if(!parentMessage){
|
||||
let nextMessages = messages;
|
||||
do {
|
||||
// let parentId = '1bf30262-8f25-4a03-88ad-9d42d55e6f0b'
|
||||
/* 没有parentMessageId就没有历史 直接返回 */
|
||||
if (!parentMessageId) {
|
||||
break;
|
||||
}
|
||||
const { text, name, role } = parentMessage
|
||||
const parentMessage = await this.getData(parentMessageId);
|
||||
if (!parentMessage) {
|
||||
break;
|
||||
}
|
||||
const { text, name, role, imageUrl } = parentMessage;
|
||||
let content = text; // 默认情况下使用text作为content
|
||||
|
||||
// 特别处理包含 imageUrl 的消息
|
||||
if (role === 'user' && imageUrl) {
|
||||
if (activeModel === 'gpt-4-vision-preview') {
|
||||
content = [
|
||||
{ type: 'text', text: text },
|
||||
{ type: 'image_url', image_url: { url: imageUrl } },
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
/* 将本轮消息插入到列表中 */
|
||||
nextMessages = nextMessages.slice(0, systemMessageOffset).concat([
|
||||
{ role, content: text, name },
|
||||
...nextMessages.slice(systemMessageOffset)
|
||||
])
|
||||
round++
|
||||
{ role, content, name }, // 使用调整后的content
|
||||
...nextMessages.slice(systemMessageOffset),
|
||||
]);
|
||||
|
||||
// Logger.debug(`nextMessages:${JSON.stringify(nextMessages, null, 2)}`);
|
||||
|
||||
round++;
|
||||
/* 如果超出了限制的最大轮次 就退出 不包含本次发送的本身 */
|
||||
if(maxRounds && round >= maxRounds){
|
||||
if (maxRounds && round >= maxRounds) {
|
||||
break;
|
||||
}
|
||||
/* 如果传入maxModelToken, maxResponseTokens 则判断是否超过边界 */
|
||||
if(maxModelToken && maxResponseTokens){
|
||||
const maxNumTokens = maxModelToken - maxResponseTokens // 模型最大token限制减去限制回复剩余空间
|
||||
if (maxModelToken && maxResponseTokens) {
|
||||
const maxNumTokens = maxModelToken - maxResponseTokens; // 模型最大token限制减去限制回复剩余空间
|
||||
/* 当前的对话历史列表合并的总token容量 */
|
||||
nextNumTokensEstimate = await this._getTokenCount(nextMessages)
|
||||
nextNumTokensEstimate = await this._getTokenCount(nextMessages);
|
||||
/* 200是添加的一个安全区间 防止少量超过 待优化 */
|
||||
const isValidPrompt = nextNumTokensEstimate + 200 <= maxNumTokens
|
||||
const isValidPrompt = nextNumTokensEstimate + 200 <= maxNumTokens;
|
||||
/* 如果大于这个区间了说明本轮加入之后导致超过限制、则递归删除头部的对话轮次来保证不出边界 */
|
||||
if(!isValidPrompt){
|
||||
nextMessages = this._recursivePruning(nextMessages, maxNumTokens, systemMessage)
|
||||
if (!isValidPrompt) {
|
||||
nextMessages = this._recursivePruning(nextMessages, maxNumTokens, systemMessage);
|
||||
}
|
||||
}
|
||||
parentMessageId = parentMessage.parentMessageId
|
||||
parentMessageId = parentMessage.parentMessageId;
|
||||
} while (true);
|
||||
const maxTokens = Math.max(
|
||||
1,
|
||||
Math.min(maxModelToken - nextNumTokensEstimate, maxResponseTokens)
|
||||
)
|
||||
const maxTokens = Math.max(1, Math.min(maxModelToken - nextNumTokensEstimate, maxResponseTokens));
|
||||
// Logger.debug(`本轮调用:模型:${model}`)
|
||||
console.log('本次携带上下文的长度',nextMessages.length, nextNumTokensEstimate )
|
||||
return { context: nextMessages, round: nextMessages.length, historyToken:nextNumTokensEstimate }
|
||||
console.log('本次携带上下文的长度', nextMessages.length, nextNumTokensEstimate);
|
||||
return { context: nextMessages, round: nextMessages.length, historyToken: nextNumTokensEstimate };
|
||||
}
|
||||
|
||||
protected _getTokenCount(messages: any[]) {
|
||||
let text = messages.reduce( (pre: string, cur: any) => {
|
||||
return pre+=cur.content
|
||||
}, '')
|
||||
text = text.replace(/<\|endoftext\|>/g, '')
|
||||
return tokenizer.encode(text).length
|
||||
}
|
||||
let text = messages.reduce((pre: string, cur: any) => {
|
||||
// 检查cur.content是否为数组
|
||||
if (Array.isArray(cur.content)) {
|
||||
// 提取并连接数组中的文本元素
|
||||
const contentText = cur.content
|
||||
.filter((item: { type: string }) => item.type === 'text')
|
||||
.map((item: { text: any }) => item.text)
|
||||
.join(' ');
|
||||
return pre + contentText;
|
||||
} else {
|
||||
// 如果不是数组,则直接添加
|
||||
return pre + (cur.content || '');
|
||||
}
|
||||
}, '');
|
||||
|
||||
text = text.replace(/<\|endoftext\|>/g, '');
|
||||
return tokenizer.encode(text).length;
|
||||
}
|
||||
|
||||
/* 递归删除 当token超过模型限制容量 删除到在限制区域内 */
|
||||
protected _recursivePruning(
|
||||
messages: MessageInfo[],
|
||||
maxNumTokens: number,
|
||||
systemMessage: string
|
||||
) {
|
||||
const currentTokens = this._getTokenCount(messages)
|
||||
protected _recursivePruning(messages: MessageInfo[], maxNumTokens: number, systemMessage: string) {
|
||||
const currentTokens = this._getTokenCount(messages);
|
||||
if (currentTokens <= maxNumTokens) {
|
||||
return messages
|
||||
return messages;
|
||||
}
|
||||
/* 有系统预设则跳过第一条删除 没有则直接删除 */
|
||||
messages.splice(systemMessage ? 1 : 0, 1)
|
||||
return this._recursivePruning(messages, maxNumTokens, systemMessage)
|
||||
messages.splice(systemMessage ? 1 : 0, 1);
|
||||
return this._recursivePruning(messages, maxNumTokens, systemMessage);
|
||||
}
|
||||
|
||||
public getUuid(){
|
||||
return uuidv4()
|
||||
public getUuid() {
|
||||
return uuidv4();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ export class MidjourneyService {
|
||||
private redisCacheService: RedisCacheService,
|
||||
) {}
|
||||
|
||||
|
||||
private lockPrompt = [];
|
||||
|
||||
/* 睡眠 xs */
|
||||
@@ -111,6 +110,9 @@ export class MidjourneyService {
|
||||
await this.updateDrawData(jobData, drawRes);
|
||||
/* 存完解锁当前文件 */
|
||||
this.lockPrompt = this.lockPrompt.filter((item) => item !== drawInfo.randomDrawId);
|
||||
|
||||
/* 只有在画成功后才扣分*/
|
||||
this.drawSuccess(jobData);
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -160,16 +162,16 @@ export class MidjourneyService {
|
||||
const { id, content, channel_id, attachments = [], timestamp, durationSpent } = drawRes;
|
||||
const { filename, url, proxy_url, width, height, size } = attachments[0];
|
||||
/* 将图片存入cos */
|
||||
const mjNotSaveImg = await this.globalConfigService.getConfigs(['mjNotSaveImg'])
|
||||
let cosUrl = ''
|
||||
if(!Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0){
|
||||
const mjNotSaveImg = await this.globalConfigService.getConfigs(['mjNotSaveImg']);
|
||||
let cosUrl = '';
|
||||
if (!Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0) {
|
||||
Logger.debug(`------> 开始上传图片!!!`, 'MidjourneyService');
|
||||
const startDate = new Date();
|
||||
cosUrl = await this.uploadService.uploadFileFromUrl({ filename, url });
|
||||
const endDate = new Date();
|
||||
Logger.debug(`本次图片上传耗时为${(endDate.getTime() - startDate.getTime()) / 1000}秒`, 'MidjourneyService');
|
||||
}else{
|
||||
console.log('本次不存图片了')
|
||||
} else {
|
||||
console.log('本次不存图片了');
|
||||
}
|
||||
/* 记录当前图片存储方式 方便后续对不同平台图片压缩 */
|
||||
const cosType = await this.uploadService.getUploadType();
|
||||
@@ -181,11 +183,11 @@ export class MidjourneyService {
|
||||
fileInfo: JSON.stringify({ width, height, size, filename, cosUrl, cosType }),
|
||||
extend: this.removeEmoji(JSON.stringify(drawRes)),
|
||||
durationSpent,
|
||||
isSaveImg: !Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0,
|
||||
isSaveImg: !Number(mjNotSaveImg) || Number(mjNotSaveImg) === 0,
|
||||
};
|
||||
await this.midjourneyEntity.update({ id: jobData.id }, drawInfo);
|
||||
} catch (error) {
|
||||
console.log('TODO->存储图片失败, ', jobData,error);
|
||||
console.log('TODO->存储图片失败, ', jobData, error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -734,17 +736,15 @@ export class MidjourneyService {
|
||||
take: size,
|
||||
skip: (page - 1) * size,
|
||||
});
|
||||
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl'])
|
||||
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl']);
|
||||
rows.forEach((item: any) => {
|
||||
try {
|
||||
const { extend, isSaveImg, fileInfo } = item;
|
||||
const originUrl = JSON.parse(extend)?.attachments[0]?.url
|
||||
const originUrl = JSON.parse(extend)?.attachments[0]?.url;
|
||||
item.fileInfo = this.formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl);
|
||||
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === "U1";
|
||||
item.originUrl = originUrl
|
||||
} catch (error) {
|
||||
|
||||
}
|
||||
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === 'U1';
|
||||
item.originUrl = originUrl;
|
||||
} catch (error) {}
|
||||
});
|
||||
const countQueue = await this.midjourneyEntity.count({ where: { isDelete: 0, status: In([1, 2]) } });
|
||||
const data: any = { rows: formatCreateOrUpdateDate(rows), count, countQueue };
|
||||
@@ -755,18 +755,18 @@ export class MidjourneyService {
|
||||
}
|
||||
|
||||
/* 格式化fileinfo 对于不同平台的图片进行压缩 */
|
||||
formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl) {
|
||||
formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl) {
|
||||
if (!fileInfo) return {};
|
||||
let parseFileInfo: any = null
|
||||
let parseFileInfo: any = null;
|
||||
try {
|
||||
parseFileInfo = JSON.parse(fileInfo);
|
||||
} catch (error) {
|
||||
parseFileInfo = null
|
||||
parseFileInfo = null;
|
||||
}
|
||||
if(!parseFileInfo) return;
|
||||
if (!parseFileInfo) return;
|
||||
const { url, filename, size, cosUrl, width, height } = parseFileInfo;
|
||||
const targetSize = 310; // 目标宽度或高度
|
||||
|
||||
|
||||
// TODO判断逻辑有误 腾讯云会导致也判断为 chevereto 更换判断规则
|
||||
const imgType = cosUrl.includes('cos') ? 'tencent' : cosUrl.includes('oss') ? 'ali' : 'chevereto';
|
||||
let compress;
|
||||
@@ -786,10 +786,10 @@ export class MidjourneyService {
|
||||
}
|
||||
parseFileInfo.thumbImg = thumbImg;
|
||||
/* 如果配置了不存储图片 则 isSaceImg 为false的则需要使用反代地址拼接 */
|
||||
if(!isSaveImg){
|
||||
const proxyImgUrl = `${mjProxyImgUrl}/mj/pipe?url=${originUrl}`
|
||||
parseFileInfo.thumbImg = proxyImgUrl
|
||||
parseFileInfo.cosUrl = proxyImgUrl
|
||||
if (!isSaveImg) {
|
||||
const proxyImgUrl = `${mjProxyImgUrl}/mj/pipe?url=${originUrl}`;
|
||||
parseFileInfo.thumbImg = proxyImgUrl;
|
||||
parseFileInfo.cosUrl = proxyImgUrl;
|
||||
}
|
||||
return parseFileInfo;
|
||||
}
|
||||
@@ -859,8 +859,8 @@ export class MidjourneyService {
|
||||
// return;
|
||||
// }
|
||||
const count = await this.midjourneyEntity.count({ where: { userId: id, isDelete: 0, status: In([1, 2]) } });
|
||||
const mjLimitCount = await this.globalConfigService.getConfigs(['mjLimitCount'])
|
||||
const max = mjLimitCount ? Number(mjLimitCount) : 2
|
||||
const mjLimitCount = await this.globalConfigService.getConfigs(['mjLimitCount']);
|
||||
const max = mjLimitCount ? Number(mjLimitCount) : 2;
|
||||
if (count >= max) {
|
||||
throw new HttpException(`当前管理员限制单用户同时最多能执行${max}个任务`, HttpStatus.BAD_REQUEST);
|
||||
}
|
||||
@@ -870,11 +870,21 @@ export class MidjourneyService {
|
||||
async drawFailed(jobData) {
|
||||
const { id, userId, action } = jobData;
|
||||
/* 退还余额 放大图片(类型2)是1 其他都是4 */
|
||||
const amount = action === 2 ? 1 : 4;
|
||||
await this.userBalanceService.refundMjBalance(userId, amount);
|
||||
// const amount = action === 2 ? 1 : 4;
|
||||
// await this.userBalanceService.refundMjBalance(userId, amount);
|
||||
await this.midjourneyEntity.update({ id }, { status: 4 });
|
||||
}
|
||||
|
||||
/* 绘图成功扣费 */
|
||||
async drawSuccess(jobData) {
|
||||
const { id, userId, action } = jobData;
|
||||
/* 扣除余额 放大图片(类型2)是1 其他都是4 */
|
||||
const amount = action === 2 ? 1 : 4;
|
||||
Logger.debug(`绘画完成,执行扣费,扣除费用:${amount}积分。`);
|
||||
await this.userBalanceService.refundMjBalance(userId, -amount);
|
||||
await this.midjourneyEntity.update({ id }, { status: 3 });
|
||||
}
|
||||
|
||||
/* 获取绘画列表 */
|
||||
async getList(params: GetListDto) {
|
||||
const { page = 1, size = 20, rec, userId, status } = params;
|
||||
@@ -902,17 +912,15 @@ export class MidjourneyService {
|
||||
skip: (page - 1) * size,
|
||||
select: ['fileInfo', 'extend', 'prompt', 'createdAt', 'id', 'extend', 'fullPrompt', 'rec', 'isSaveImg'],
|
||||
});
|
||||
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl'])
|
||||
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl']);
|
||||
rows.forEach((item: any) => {
|
||||
try {
|
||||
const { extend, isSaveImg, fileInfo } = item;
|
||||
const originUrl = JSON.parse(extend)?.attachments[0]?.url
|
||||
const originUrl = JSON.parse(extend)?.attachments[0]?.url;
|
||||
item.fileInfo = this.formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl);
|
||||
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === "U1";
|
||||
item.originUrl = originUrl
|
||||
} catch (error) {
|
||||
|
||||
}
|
||||
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === 'U1';
|
||||
item.originUrl = originUrl;
|
||||
} catch (error) {}
|
||||
});
|
||||
|
||||
if (Number(size) === 999) {
|
||||
@@ -931,10 +939,10 @@ export class MidjourneyService {
|
||||
}
|
||||
|
||||
/* */
|
||||
async getFullPrompt(id: number){
|
||||
const m = await this.midjourneyEntity.findOne({where: {id}})
|
||||
if(!m) return ''
|
||||
const { fullPrompt } = m
|
||||
async getFullPrompt(id: number) {
|
||||
const m = await this.midjourneyEntity.findOne({ where: { id } });
|
||||
if (!m) return '';
|
||||
const { fullPrompt } = m;
|
||||
return fullPrompt;
|
||||
}
|
||||
|
||||
@@ -953,27 +961,25 @@ export class MidjourneyService {
|
||||
skip: (page - 1) * size,
|
||||
});
|
||||
|
||||
const userIds = rows.map((item: any) => item.userId).filter( id => id < 100000);
|
||||
const userIds = rows.map((item: any) => item.userId).filter((id) => id < 100000);
|
||||
const userInfos = await this.userEntity.find({ where: { id: In(userIds) }, select: ['id', 'username', 'avatar', 'email'] });
|
||||
rows.forEach((item: any) => {
|
||||
item.userInfo = userInfos.find((user) => user.id === item.userId);
|
||||
});
|
||||
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl'])
|
||||
const mjProxyImgUrl = await this.globalConfigService.getConfigs(['mjProxyImgUrl']);
|
||||
rows.forEach((item: any) => {
|
||||
try {
|
||||
const { extend, isSaveImg, fileInfo } = item;
|
||||
const originUrl = JSON.parse(extend)?.attachments[0]?.url
|
||||
const originUrl = JSON.parse(extend)?.attachments[0]?.url;
|
||||
item.fileInfo = this.formatFileInfo(fileInfo, isSaveImg, mjProxyImgUrl, originUrl);
|
||||
// item.isGroup = JSON.parse(extend)?.components[0]?.components.length === 5;
|
||||
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === "U1";
|
||||
item.originUrl = originUrl
|
||||
} catch (error) {
|
||||
|
||||
}
|
||||
item.isGroup = JSON.parse(extend)?.components[0]?.components[0].label === 'U1';
|
||||
item.originUrl = originUrl;
|
||||
} catch (error) {}
|
||||
});
|
||||
if (req.user.role !== 'super') {
|
||||
rows.forEach((item: any) => {
|
||||
if(item.userInfo && item.userInfo.email){
|
||||
if (item.userInfo && item.userInfo.email) {
|
||||
item.userInfo.email = item.userInfo.email.replace(/(.{2}).+(.{2}@.+)/, '$1****$2');
|
||||
}
|
||||
});
|
||||
@@ -1021,38 +1027,38 @@ export class MidjourneyService {
|
||||
}
|
||||
}
|
||||
|
||||
async setPrompt(req: Request, body){
|
||||
async setPrompt(req: Request, body) {
|
||||
try {
|
||||
const { prompt, status, isCarryParams, title, order, id, aspect } = body
|
||||
if(id){
|
||||
return await this.mjPromptsEntity.update({id}, {prompt, status, isCarryParams, order, aspect})
|
||||
}else{
|
||||
return await this.mjPromptsEntity.save({prompt, status, isCarryParams, title, order, aspect})
|
||||
}
|
||||
const { prompt, status, isCarryParams, title, order, id, aspect } = body;
|
||||
if (id) {
|
||||
return await this.mjPromptsEntity.update({ id }, { prompt, status, isCarryParams, order, aspect });
|
||||
} else {
|
||||
return await this.mjPromptsEntity.save({ prompt, status, isCarryParams, title, order, aspect });
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('error: ', error);
|
||||
}
|
||||
}
|
||||
|
||||
async delPrompt(req: Request, body){
|
||||
const {id} = body
|
||||
if(!id) {
|
||||
async delPrompt(req: Request, body) {
|
||||
const { id } = body;
|
||||
if (!id) {
|
||||
throw new HttpException('非法操作!', HttpStatus.BAD_REQUEST);
|
||||
}
|
||||
return await this.mjPromptsEntity.delete({id})
|
||||
return await this.mjPromptsEntity.delete({ id });
|
||||
}
|
||||
|
||||
async queryPrompt(){
|
||||
async queryPrompt() {
|
||||
return await this.mjPromptsEntity.find({
|
||||
order: { order: 'DESC' },
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
async proxyImg(params){
|
||||
const { url } = params
|
||||
if(!url) return
|
||||
async proxyImg(params) {
|
||||
const { url } = params;
|
||||
if (!url) return;
|
||||
const response = await axios.get(url, { responseType: 'arraybuffer' });
|
||||
const base64 = Buffer.from(response.data).toString('base64');
|
||||
return base64
|
||||
return base64;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,4 +56,6 @@ export class SetModelDto {
|
||||
//设置token计费
|
||||
@ApiProperty({ example: true, description: '是否使用token计费', required: false })
|
||||
isTokenBased: boolean;
|
||||
@ApiProperty({ example: true, description: 'token计费比例', required: false })
|
||||
tokenFeeRatio: number;
|
||||
}
|
||||
|
||||
@@ -66,4 +66,7 @@ export class ModelsEntity extends BaseEntity {
|
||||
|
||||
@Column({ comment: '是否使用token计费: 0:不是 1:是', default: 0 })
|
||||
isTokenBased: boolean;
|
||||
|
||||
@Column({ comment: 'token计费比例', default: 0 })
|
||||
tokenFeeRatio: number;
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
/* 绘图和图生图扣除余额4 */
|
||||
this.jobIds.push(job.id);
|
||||
/* 扣费 */
|
||||
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
/* 扣费 */
|
||||
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 1, 1);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 1, 1);
|
||||
this.jobIds.push(job.id);
|
||||
return;
|
||||
}
|
||||
@@ -83,7 +83,7 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
/* 扣费 */
|
||||
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,8 +300,8 @@ export class UserBalanceService {
|
||||
/* 记录修改使用的token */
|
||||
const updateBalance = { [updateKey]: b[updateKey] - amount < 0 ? 0 : b[updateKey] - amount, [useKey]: b[useKey] + UseAmount };
|
||||
/* 记录修改使用的次数 mj不需要 */
|
||||
useKey === 'useModel3Token' && (updateBalance['useModel3Count'] = b['useModel3Count'] + 1);
|
||||
useKey === 'useModel4Token' && (updateBalance['useModel4Count'] = b['useModel4Count'] + 1);
|
||||
useKey === 'useModel3Token' && (updateBalance['useModel3Count'] = b['useModel3Count'] + amount);
|
||||
useKey === 'useModel4Token' && (updateBalance['useModel4Count'] = b['useModel4Count'] + amount);
|
||||
const result = await this.userBalanceEntity.update({ userId }, updateBalance);
|
||||
if (result.affected === 0) {
|
||||
throw new HttpException('消费余额失败!', HttpStatus.BAD_REQUEST);
|
||||
@@ -552,7 +552,9 @@ export class UserBalanceService {
|
||||
}
|
||||
|
||||
/* MJ绘画失败退款 */
|
||||
async refundMjBalance(userId, amount) {}
|
||||
async refundMjBalance(userId, amount) {
|
||||
return await this.deductFromBalance(userId, 'mjDraw', -amount);
|
||||
}
|
||||
|
||||
/* V1.5升级将旧版本余额并入到新表 */
|
||||
async upgradeBalance() {
|
||||
|
||||
Reference in New Issue
Block a user