mirror of https://github.com/xiaoyiweb/YiAi.git
synced 2025-11-13 20:53:47 +08:00
Version 2.5: add dall-e, optimize mj, integrate with mj-plus
BIN service/src/.DS_Store (vendored)
Binary file not shown.
@@ -2,8 +2,8 @@ import { UploadService } from './../upload/upload.service';
import { UserService } from './../user/user.service';
import { ConfigService } from 'nestjs-config';
import { HttpException, HttpStatus, Injectable, OnModuleInit, Logger } from '@nestjs/common';
import type { ChatGPTAPIOptions, ChatMessage, SendMessageOptions } from 'chatgpt-nine-ai';
import { Request, Response } from 'express';
import type { ChatGPTAPIOptions, ChatMessage, SendMessageOptions } from 'chatgpt-ai-web';
import e, { Request, Response } from 'express';
import { OpenAiErrorCodeMessage } from '@/common/constants/errorMessage.constant';
import {
compileNetwork,
@@ -97,7 +97,7 @@ export class ChatgptService implements OnModuleInit {
};

async onModuleInit() {
let chatgpt = await importDynamic('chatgpt-nine-ai');
let chatgpt = await importDynamic('chatgpt-ai-web');
let KeyvRedis = await importDynamic('@keyv/redis');
let Keyv = await importDynamic('keyv');
chatgpt = chatgpt?.default ? chatgpt.default : chatgpt;
@@ -165,7 +165,7 @@ export class ChatgptService implements OnModuleInit {
/* 不同场景会变更其信息 */
let setSystemMessage = systemMessage;
const { parentMessageId } = options;
const { prompt ,imageUrl,model:activeModel} = body;
const { prompt, imageUrl, model: activeModel } = body;
const { groupId, usingNetwork } = options;
// const { model = 3 } = options;
/* 获取当前对话组的详细配置信息 */
@@ -184,7 +184,7 @@ export class ChatgptService implements OnModuleInit {
throw new HttpException('当前流程所需要的模型已被管理员下架、请联系管理员上架专属模型!', HttpStatus.BAD_REQUEST);
}

const { deduct, isTokenBased, deductType, key: modelKey, secret, modelName, id: keyId, accessToken } = currentRequestModelKey;
const { deduct, isTokenBased, tokenFeeRatio, deductType, key: modelKey, secret, modelName, id: keyId, accessToken } = currentRequestModelKey;
/* 用户状态检测 */
await this.userService.checkUserStatus(req.user);
/* 用户余额检测 */
@@ -260,7 +260,7 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt,
imageUrl,
imageUrl:response?.imageUrl,
activeModel,
answer: '',
promptTokens: prompt_tokens,
@@ -307,7 +307,7 @@ export class ChatgptService implements OnModuleInit {
/* 当用户回答一般停止时 也需要扣费 */
let charge = deduct;
if (isTokenBased === true) {
charge = deduct * total_tokens;
charge = Math.ceil((deduct * total_tokens) / tokenFeeRatio);
}
await this.userBalanceService.deductFromBalance(req.user.id, `model${deductType === 1 ? 3 : 4}`, charge, total_tokens);
});
@@ -320,11 +320,11 @@ export class ChatgptService implements OnModuleInit {
const { context: messagesHistory } = await this.nineStore.buildMessageFromParentMessageId(usingNetwork ? netWorkPrompt : prompt, {
parentMessageId,
systemMessage,
imageUrl,
activeModel,
maxModelToken: maxToken,
maxResponseTokens: maxTokenRes,
maxRounds: addOneIfOdd(rounds),
imageUrl,
activeModel,
});
let firstChunk = true;
response = await sendMessageFromOpenAi(messagesHistory, {
@@ -332,8 +332,9 @@ export class ChatgptService implements OnModuleInit {
maxTokenRes,
apiKey: modelKey,
model,
imageUrl,
prompt,
activeModel,
imageUrl,
temperature,
proxyUrl: proxyResUrl,
onProgress: (chat) => {
@@ -341,7 +342,7 @@ export class ChatgptService implements OnModuleInit {
lastChat = chat;
firstChunk = false;
},
});
},this.uploadService);
isSuccess = true;
}

@@ -385,7 +386,6 @@ export class ChatgptService implements OnModuleInit {
isSuccess = true;
}

/* 分别将本次用户输入的 和 机器人返回的分两次存入到 store */
const userMessageData: MessageInfo = {
id: this.nineStore.getUuid(),
text: prompt,
@@ -407,7 +407,8 @@ export class ChatgptService implements OnModuleInit {
text: response.text,
role: 'assistant',
name: undefined,
usage: response.usage,
usage: response?.usage,
imageUrl: response?.imageUrl,
parentMessageId: userMessageData.id,
conversationId: response?.conversationId,
};
@@ -415,7 +416,6 @@ export class ChatgptService implements OnModuleInit {
await this.nineStore.setData(assistantMessageData);

othersInfo = { model, parentMessageId: userMessageData.id };
/* 回答完毕 */
} else {
const { key, maxToken, maxTokenRes, proxyResUrl } = await this.formatModelToken(currentRequestModelKey);
const { parentMessageId, completionParams, systemMessage } = mergedOptions;
@@ -431,17 +431,22 @@ export class ChatgptService implements OnModuleInit {
temperature,
proxyUrl: proxyResUrl,
onProgress: null,
prompt,
});
}

/* 统一最终输出格式 */
const formatResponse = await unifiedFormattingResponse(keyType, response, othersInfo);
const { prompt_tokens = 0, completion_tokens = 0, total_tokens = 0 } = formatResponse.usage;

let usage = null;
let formatResponse = null;
if (model.includes('dall')) {
usage = response.detail?.usage || { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 };
} else {
formatResponse = await unifiedFormattingResponse(keyType, response, othersInfo);
}
const { prompt_tokens, completion_tokens, total_tokens } = model.includes('dall') ? usage : formatResponse.usage;
/* 区分扣除普通还是高级余额 model3: 普通余额 model4: 高级余额 */
let charge = deduct;
if (isTokenBased === true) {
charge = deduct * total_tokens;
charge = Math.ceil((deduct * total_tokens) / tokenFeeRatio);
}
await this.userBalanceService.deductFromBalance(req.user.id, `model${deductType === 1 ? 3 : 4}`, charge, total_tokens);

@@ -457,13 +462,13 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt,
imageUrl,
imageUrl: response?.imageUrl,
activeModel,
answer: '',
promptTokens: prompt_tokens,
completionTokens: 0,
totalTokens: total_tokens,
model: formatResponse.model,
model: model,
role: 'user',
groupId,
requestOptions: JSON.stringify({
@@ -479,7 +484,8 @@ export class ChatgptService implements OnModuleInit {
userId: req.user.id,
type: DeductionKey.CHAT_TYPE,
prompt: prompt,
answer: formatResponse?.text,
imageUrl: response?.imageUrl,
answer: response.text,
promptTokens: prompt_tokens,
completionTokens: completion_tokens,
totalTokens: total_tokens,
@@ -501,7 +507,7 @@ export class ChatgptService implements OnModuleInit {
}),
});
Logger.debug(
`本次调用: ${req.user.id} model: ${model} key -> ${key}, 模型名称: ${modelName}, 最大回复token: ${maxResponseTokens}`,
`用户ID: ${req.user.id} 模型名称: ${modelName}-${activeModel}, 消耗token: ${total_tokens}, 消耗积分: ${charge}`,
'ChatgptService',
);
const userBalance = await this.userBalanceService.queryUserBalance(req.user.id);
@@ -599,7 +605,7 @@ export class ChatgptService implements OnModuleInit {
await this.userBalanceService.validateBalance(req, 'mjDraw', money);
let images = [];
/* 从3的卡池随机拿一个key */
const detailKeyInfo = await this.modelsService.getRandomDrawKey();
const detailKeyInfo = await this.modelsService.getCurrentModelKeyInfo('dall-e-3');
const keyId = detailKeyInfo?.id;
const { key, proxyResUrl } = await this.formatModelToken(detailKeyInfo);
Logger.log(`draw paompt info <==**==> ${body.prompt}, key ===> ${key}`, 'DrawService');

@@ -1,6 +1,9 @@
|
||||
import axios, { AxiosRequestConfig, AxiosResponse } from 'axios';
|
||||
import { get_encoding } from '@dqbd/tiktoken'
|
||||
import { removeSpecialCharacters } from '@/common/utils';
|
||||
import { ConsoleLogger, HttpException, HttpStatus, Logger } from '@nestjs/common';
|
||||
import * as uuid from 'uuid';
|
||||
|
||||
|
||||
const tokenizer = get_encoding('cl100k_base')
|
||||
|
||||
@@ -11,96 +14,166 @@ interface SendMessageResult {
|
||||
detail?: any;
|
||||
}
|
||||
|
||||
function getFullUrl(proxyUrl){
|
||||
function getFullUrl(proxyUrl) {
|
||||
const processedUrl = proxyUrl.endsWith('/') ? proxyUrl.slice(0, -1) : proxyUrl;
|
||||
const baseUrl = processedUrl || 'https://api.openai.com'
|
||||
return `${baseUrl}/v1/chat/completions`
|
||||
}
|
||||
|
||||
export function sendMessageFromOpenAi(messagesHistory, inputs ){
|
||||
const { onProgress, maxToken, apiKey, model, temperature = 0.95, proxyUrl } = inputs
|
||||
console.log('current request options: ',apiKey, model, maxToken, proxyUrl );
|
||||
const max_tokens = compilerToken(model, maxToken)
|
||||
const options: AxiosRequestConfig = {
|
||||
method: 'POST',
|
||||
url: getFullUrl(proxyUrl),
|
||||
responseType: 'stream',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${removeSpecialCharacters(apiKey)}`,
|
||||
},
|
||||
data: {
|
||||
max_tokens,
|
||||
stream: true,
|
||||
temperature,
|
||||
model,
|
||||
messages: messagesHistory
|
||||
},
|
||||
};
|
||||
const prompt = messagesHistory[messagesHistory.length-1]?.content
|
||||
return new Promise(async (resolve, reject) =>{
|
||||
export async function sendMessageFromOpenAi(messagesHistory, inputs, uploadService?) {
|
||||
const { onProgress, maxToken, apiKey, model, temperature = 0.8, proxyUrl, prompt } = inputs
|
||||
if (model.includes('dall')) {
|
||||
let result: any = { text: '', imageUrl: '' };
|
||||
try {
|
||||
const options: AxiosRequestConfig = {
|
||||
method: 'POST',
|
||||
url: `${proxyUrl}/v1/images/generations`,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
data: {
|
||||
prompt: prompt,
|
||||
model: model,
|
||||
response_format: 'b64_json'
|
||||
},
|
||||
}
|
||||
const response: any = await axios(options);
|
||||
const stream = response.data;
|
||||
let result: any = { text: '' };
|
||||
stream.on('data', (chunk) => {
|
||||
const splitArr = chunk.toString().split('\n\n').filter((line) => line.trim() !== '');
|
||||
for (const line of splitArr) {
|
||||
const data = line.replace('data:', '');
|
||||
let ISEND = false;
|
||||
try {
|
||||
ISEND = JSON.parse(data).choices[0].finish_reason === 'stop';
|
||||
} catch (error) {
|
||||
ISEND = false;
|
||||
}
|
||||
/* 如果结束 返回所有 */
|
||||
if (data === '[DONE]' || ISEND) {
|
||||
result.text = result.text.trim();
|
||||
return result;
|
||||
}
|
||||
try {
|
||||
const parsedData = JSON.parse(data);
|
||||
if (parsedData.id) {
|
||||
result.id = parsedData.id;
|
||||
}
|
||||
if (parsedData.choices?.length) {
|
||||
const delta = parsedData.choices[0].delta;
|
||||
result.delta = delta.content;
|
||||
if (delta?.content) result.text += delta.content;
|
||||
if (delta.role) {
|
||||
result.role = delta.role;
|
||||
}
|
||||
result.detail = parsedData;
|
||||
}
|
||||
onProgress && onProgress({text:result.text})
|
||||
} catch (error) {
|
||||
console.log('parse Error', data )
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
stream.on('end', () => {
|
||||
// 手动计算token
|
||||
if(result.detail && result.text){
|
||||
const promptTokens = getTokenCount(prompt)
|
||||
const completionTokens = getTokenCount(result.text)
|
||||
result.detail.usage = {
|
||||
prompt_tokens: promptTokens,
|
||||
completion_tokens: completionTokens ,
|
||||
total_tokens: promptTokens + completionTokens,
|
||||
estimated: true
|
||||
}
|
||||
}
|
||||
return resolve(result);
|
||||
});
|
||||
const { b64_json, revised_prompt } = response.data.data[0]
|
||||
const buffer = Buffer.from(b64_json, 'base64');
|
||||
let imgUrl = '';
|
||||
try {
|
||||
const filename = uuid.v4().slice(0, 10) + '.png';
|
||||
Logger.debug(`------> 开始上传图片!!!`, 'MidjourneyService');
|
||||
const buffer = Buffer.from(b64_json, 'base64');
|
||||
// imgUrl = await uploadService.uploadFileFromUrl({ filename, url })
|
||||
imgUrl = await uploadService.uploadFile({ filename, buffer });
|
||||
Logger.debug(`图片上传成功,URL: ${imgUrl}`, 'MidjourneyService');
|
||||
} catch (error) {
|
||||
Logger.error(`上传图片过程中出现错误: ${error}`, 'MidjourneyService');
|
||||
}
|
||||
result.imageUrl = imgUrl
|
||||
result.text = revised_prompt;
|
||||
onProgress && onProgress({ text: result.text })
|
||||
return result;
|
||||
} catch (error) {
|
||||
reject(error)
|
||||
const status = error?.response?.status || 500;
|
||||
console.log('openai-draw error: ', JSON.stringify(error), status);
|
||||
const message = error?.response?.data?.error?.message;
|
||||
if (status === 429) {
|
||||
result.text = '当前请求已过载、请稍等会儿再试试吧!';
|
||||
return result;
|
||||
}
|
||||
if (status === 400 && message.includes('This request has been blocked by our content filters')) {
|
||||
result.text = '您的请求已被系统拒绝。您的提示可能存在一些非法的文本。';
|
||||
return result;
|
||||
}
|
||||
if (status === 400 && message.includes('Billing hard limit has been reached')) {
|
||||
result.text = '当前模型key已被封禁、已冻结当前调用Key、尝试重新对话试试吧!';
|
||||
return result;
|
||||
}
|
||||
if (status === 500) {
|
||||
result.text = '绘制图片失败,请检查你的提示词是否有非法描述!';
|
||||
return result;
|
||||
}
|
||||
if (status === 401) {
|
||||
result.text = '绘制图片失败,此次绘画被拒绝了!';
|
||||
return result;
|
||||
}
|
||||
result.text = '绘制图片失败,请稍后试试吧!';
|
||||
return result;
|
||||
}
|
||||
})
|
||||
} else {
|
||||
let result: any = { text: '' };
|
||||
const options: AxiosRequestConfig = {
|
||||
method: 'POST',
|
||||
url: getFullUrl(proxyUrl),
|
||||
responseType: 'stream',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
data: {
|
||||
stream: true,
|
||||
temperature,
|
||||
model,
|
||||
messages: messagesHistory,
|
||||
},
|
||||
};
|
||||
|
||||
if (model === 'gpt-4-vision-preview') {
|
||||
options.data.max_tokens = 2048;
|
||||
}
|
||||
|
||||
return new Promise(async (resolve, reject) => {
|
||||
try {
|
||||
const response: any = await axios(options);
|
||||
const stream = response.data;
|
||||
|
||||
stream.on('data', (chunk) => {
|
||||
const splitArr = chunk.toString().split('\n\n').filter((line) => line.trim() !== '');
|
||||
for (const line of splitArr) {
|
||||
const data = line.replace('data:', '');
|
||||
let ISEND = false;
|
||||
try {
|
||||
ISEND = JSON.parse(data).choices[0].finish_reason === 'stop';
|
||||
} catch (error) {
|
||||
ISEND = false;
|
||||
}
|
||||
/* 如果结束 返回所有 */
|
||||
if (ISEND) {
|
||||
result.text = result.text.trim();
|
||||
return result;
|
||||
}
|
||||
try {
|
||||
if (data !== " [DONE]" && data !== "[DONE]" && data != "[DONE] ") {
|
||||
const parsedData = JSON.parse(data);
|
||||
if (parsedData.id) {
|
||||
result.id = parsedData.id;
|
||||
}
|
||||
if (parsedData.choices?.length) {
|
||||
const delta = parsedData.choices[0].delta;
|
||||
result.delta = delta.content;
|
||||
if (delta?.content) result.text += delta.content;
|
||||
if (delta.role) {
|
||||
result.role = delta.role;
|
||||
}
|
||||
result.detail = parsedData;
|
||||
}
|
||||
onProgress && onProgress({ text: result.text })
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('parse Error', data)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let totalText = '';
|
||||
messagesHistory.forEach(message => {
|
||||
totalText += message.content + ' ';
|
||||
});
|
||||
stream.on('end', () => {
|
||||
// 手动计算token
|
||||
if (result.detail && result.text) {
|
||||
const promptTokens = getTokenCount(totalText)
|
||||
const completionTokens = getTokenCount(result.text)
|
||||
result.detail.usage = {
|
||||
prompt_tokens: promptTokens,
|
||||
completion_tokens: completionTokens,
|
||||
total_tokens: promptTokens + completionTokens,
|
||||
estimated: true
|
||||
}
|
||||
}
|
||||
return resolve(result);
|
||||
});
|
||||
} catch (error) {
|
||||
reject(error)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
export function getTokenCount(text: string) {
|
||||
if (!text) return 0;
|
||||
// 确保text是字符串类型
|
||||
@@ -111,28 +184,4 @@ export function getTokenCount(text: string) {
|
||||
return tokenizer.encode(text).length
|
||||
}
|
||||
|
||||
function compilerToken(model, maxToken){
|
||||
let max = 0
|
||||
|
||||
/* 3.5 */
|
||||
if(model.includes(3.5)){
|
||||
max = maxToken > 4096 ? 4096 : maxToken
|
||||
}
|
||||
|
||||
/* 4.0 */
|
||||
if(model.includes('gpt-4')){
|
||||
max = maxToken > 8192 ? 8192 : maxToken
|
||||
}
|
||||
|
||||
/* 4.0 preview */
|
||||
if(model.includes('preview')){
|
||||
max = maxToken > 4096 ? 4096 : maxToken
|
||||
}
|
||||
|
||||
/* 4.0 32k */
|
||||
if(model.includes('32k')){
|
||||
max = maxToken > 32768 ? 32768 : maxToken
|
||||
}
|
||||
|
||||
return max
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import Keyv from 'keyv';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { get_encoding } from '@dqbd/tiktoken';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { includes } from 'lodash';
|
||||
|
||||
const tokenizer = get_encoding('cl100k_base');
|
||||
|
||||
@@ -96,7 +97,7 @@ export class NineStore implements NineStoreInterface {
|
||||
let nextNumTokensEstimate = 0;
|
||||
// messages.push({ role: 'system', content: systemMessage, name })
|
||||
if (systemMessage) {
|
||||
const specialModels = ['gemini-pro', 'ERNIE', 'qwen', 'SparkDesk', 'hunyuan'];
|
||||
const specialModels = ['gemini-pro', 'ERNIE','hunyuan'];
|
||||
const isSpecialModel = activeModel && specialModels.some((specialModel) => activeModel.includes(specialModel));
|
||||
if (isSpecialModel) {
|
||||
messages.push({ role: 'user', content: systemMessage, name });
|
||||
@@ -146,7 +147,7 @@ export class NineStore implements NineStoreInterface {
|
||||
let content = text; // 默认情况下使用text作为content
|
||||
|
||||
// 特别处理包含 imageUrl 的消息
|
||||
if (role === 'user' && imageUrl) {
|
||||
if (imageUrl) {
|
||||
if (activeModel === 'gpt-4-vision-preview') {
|
||||
content = [
|
||||
{ type: 'text', text: text },
|
||||
|
||||
@@ -13,7 +13,7 @@ interface UserInfo {
|
||||
|
||||
@Injectable()
|
||||
export class DatabaseService implements OnModuleInit {
|
||||
constructor(private connection: Connection) {}
|
||||
constructor(private connection: Connection) { }
|
||||
async onModuleInit() {
|
||||
await this.checkSuperAdmin();
|
||||
await this.checkSiteBaseConfig();
|
||||
@@ -23,7 +23,7 @@ export class DatabaseService implements OnModuleInit {
|
||||
async checkSuperAdmin() {
|
||||
const user = await this.connection.query(`SELECT * FROM users WHERE role = 'super'`);
|
||||
if (!user || user.length === 0) {
|
||||
const superPassword = bcrypt.hashSync('123456', 10);
|
||||
const superPassword = bcrypt.hashSync('123456', 10); //初始密码
|
||||
const adminPassword = bcrypt.hashSync('123456', 10);
|
||||
const superEmail = 'default@cooper.com';
|
||||
const adminEmail = 'defaultAdmin@cooper.com';
|
||||
@@ -44,7 +44,7 @@ export class DatabaseService implements OnModuleInit {
|
||||
const userId = user.insertId;
|
||||
const balance = await this.connection.query(`INSERT INTO balance (userId, balance, usesLeft, paintCount) VALUES ('${userId}', 0, 1000, 100)`);
|
||||
Logger.log(
|
||||
`初始化创建${role}用户成功、用户名为[${username}]、初始密码为[${username === 'super' ? '123456' : '123456'}] ==============> 请注意查阅`,
|
||||
`初始化创建${role}用户成功、用户名为[${username}]、初始密码为[${username === 'super' ? 'nine-super' : '123456'}] ==============> 请注意查阅`,
|
||||
'DatabaseService',
|
||||
);
|
||||
} catch (error) {
|
||||
@@ -68,17 +68,7 @@ export class DatabaseService implements OnModuleInit {
|
||||
/* 创建基础的网站数据 */
|
||||
async createBaseSiteConfig() {
|
||||
try {
|
||||
const code = `
|
||||
<script>
|
||||
var _hmt = _hmt || [];
|
||||
(function() {
|
||||
var hm = document.createElement("script");
|
||||
hm.src = "https://hm.baidu.com/hm.js?cb8c9a3bcadbc200e950b05f9c61a385";
|
||||
var s = document.getElementsByTagName("script")[0];
|
||||
s.parentNode.insertBefore(hm, s);
|
||||
})();
|
||||
</script>
|
||||
`;
|
||||
const code = ``;
|
||||
|
||||
const noticeInfo = `
|
||||
#### YiAi 欢迎您
|
||||
@@ -88,21 +78,21 @@ export class DatabaseService implements OnModuleInit {
|
||||
`;
|
||||
|
||||
const defaultConfig = [
|
||||
{ configKey: 'siteName', configVal: 'Nine Ai', public: 1, encry: 0 },
|
||||
{ configKey: 'qqNumber', configVal: '840814166', public: 1, encry: 0 },
|
||||
{ configKey: 'vxNumber', configVal: 'wangpanzhu321', public: 1, encry: 0 },
|
||||
{ configKey: 'siteName', configVal: 'Yi Ai', public: 1, encry: 0 },
|
||||
{ configKey: 'qqNumber', configVal: '805239273', public: 1, encry: 0 },
|
||||
{ configKey: 'vxNumber', configVal: 'HelloWordYi819', public: 1, encry: 0 },
|
||||
{ configKey: 'robotAvatar', configVal: '', public: 1, encry: 0 },
|
||||
{
|
||||
configKey: 'userDefautlAvatar',
|
||||
configVal: 'https://public-1300678944.cos.ap-shanghai.myqcloud.com/blog/1682571295452image.png',
|
||||
configVal: '',
|
||||
public: 0,
|
||||
encry: 0,
|
||||
},
|
||||
{ configKey: 'baiduCode', configVal: code, public: 1, encry: 0 },
|
||||
{ configKey: 'baiduSiteId', configVal: '19024441', public: 0, encry: 0 },
|
||||
{ configKey: 'baiduSiteId', configVal: '', public: 0, encry: 0 },
|
||||
{
|
||||
configKey: 'baiduToken',
|
||||
configVal: '121.a1600b9b60910feea2ef627ea9776a6f.YGP_CWCOA2lNcIGJ27BwXGxa6nZhBQyLUS4XVaD.TWt9TA',
|
||||
configVal: '',
|
||||
public: 0,
|
||||
encry: 0,
|
||||
},
|
||||
@@ -110,25 +100,25 @@ export class DatabaseService implements OnModuleInit {
|
||||
{ configKey: 'openaiBaseUrl', configVal: 'https://api.openai.com', public: 0, encry: 0 },
|
||||
{ configKey: 'noticeInfo', configVal: noticeInfo, public: 1, encry: 0 },
|
||||
|
||||
{ configKey: 'registerVerifyEmailTitle', configVal: 'NineTeam团队账号验证', public: 0, encry: 0 },
|
||||
{ configKey: 'registerVerifyEmailTitle', configVal: 'Yi Ai团队账号验证', public: 0, encry: 0 },
|
||||
{
|
||||
configKey: 'registerVerifyEmailDesc',
|
||||
configVal: '欢迎使用Nine Team团队的产品服务,请在五分钟内完成你的账号激活,点击以下按钮激活您的账号,',
|
||||
configVal: '欢迎使用Yi Ai团队的产品服务,请在五分钟内完成你的账号激活,点击以下按钮激活您的账号,',
|
||||
public: 0,
|
||||
encry: 0,
|
||||
},
|
||||
{ configKey: 'registerVerifyEmailFrom', configVal: 'NineTeam团队', public: 0, encry: 0 },
|
||||
{ configKey: 'registerVerifyEmailFrom', configVal: 'Yi Ai团队', public: 0, encry: 0 },
|
||||
{ configKey: 'registerVerifyExpir', configVal: '1800', public: 0, encry: 0 },
|
||||
{ configKey: 'registerSuccessEmailTitle', configVal: 'NineTeam团队账号激活成功', public: 0, encry: 0 },
|
||||
{ configKey: 'registerSuccessEmailTeamName', configVal: 'NineTeam团队', public: 0, encry: 0 },
|
||||
{ configKey: 'registerSuccessEmailTitle', configVal: 'Yi Ai账号激活成功', public: 0, encry: 0 },
|
||||
{ configKey: 'registerSuccessEmailTeamName', configVal: 'Yi Ai', public: 0, encry: 0 },
|
||||
{
|
||||
configKey: 'registerSuccessEmaileAppend',
|
||||
configVal: ',请妥善保管您的账号,我们将为您赠送50次对话额度和5次绘画额度、祝您使用愉快',
|
||||
configVal: ',请妥善保管您的账号,祝您使用愉快',
|
||||
public: 0,
|
||||
encry: 0,
|
||||
},
|
||||
{ configKey: 'registerFailEmailTitle', configVal: 'NineTeam账号激活失败', public: 0, encry: 0 },
|
||||
{ configKey: 'registerFailEmailTeamName', configVal: 'NineTeam团队', public: 0, encry: 0 },
|
||||
{ configKey: 'registerFailEmailTitle', configVal: 'Yi Ai账号激活失败', public: 0, encry: 0 },
|
||||
{ configKey: 'registerFailEmailTeamName', configVal: 'Yi Ai团队', public: 0, encry: 0 },
|
||||
/* 注册默认设置 */
|
||||
{ configKey: 'registerSendStatus', configVal: '1', public: 1, encry: 0 },
|
||||
{ configKey: 'registerSendModel3Count', configVal: '30', public: 1, encry: 0 },
|
||||
@@ -136,16 +126,16 @@ export class DatabaseService implements OnModuleInit {
|
||||
{ configKey: 'registerSendDrawMjCount', configVal: '3', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRegisterSendStatus', configVal: '1', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRegisterSendRank', configVal: '500', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRregisterSendModel3Count', configVal: '20', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRregisterSendModel4Count', configVal: '2', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRregisterSendDrawMjCount', configVal: '3', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRregisterSendModel3Count', configVal: '10', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRregisterSendModel4Count', configVal: '10', public: 1, encry: 0 },
|
||||
{ configKey: 'firstRregisterSendDrawMjCount', configVal: '10', public: 1, encry: 0 },
|
||||
{ configKey: 'inviteSendStatus', configVal: '1', public: 1, encry: 0 },
|
||||
{ configKey: 'inviteGiveSendModel3Count', configVal: '30', public: 1, encry: 0 },
|
||||
{ configKey: 'inviteGiveSendModel4Count', configVal: '3', public: 1, encry: 0 },
|
||||
{ configKey: 'inviteGiveSendDrawMjCount', configVal: '1', public: 1, encry: 0 },
|
||||
{ configKey: 'inviteGiveSendModel3Count', configVal: '0', public: 1, encry: 0 },
|
||||
{ configKey: 'inviteGiveSendModel4Count', configVal: '0', public: 1, encry: 0 },
|
||||
{ configKey: 'inviteGiveSendDrawMjCount', configVal: '0', public: 1, encry: 0 },
|
||||
{ configKey: 'invitedGuestSendModel3Count', configVal: '10', public: 1, encry: 0 },
|
||||
{ configKey: 'invitedGuestSendModel4Count', configVal: '1', public: 1, encry: 0 },
|
||||
{ configKey: 'invitedGuestSendDrawMjCount', configVal: '1', public: 1, encry: 0 },
|
||||
{ configKey: 'invitedGuestSendModel4Count', configVal: '10', public: 1, encry: 0 },
|
||||
{ configKey: 'invitedGuestSendDrawMjCount', configVal: '10', public: 1, encry: 0 },
|
||||
{ configKey: 'isVerifyEmail', configVal: '1', public: 1, encry: 0 },
|
||||
];
|
||||
|
||||
|
||||
@@ -22,9 +22,6 @@ export class MidjourneyEntity extends BaseEntity {
|
||||
@Column({ comment: '垫图图片 + 绘画描述词 + 额外参数 = 完整的prompt', type: 'text' })
|
||||
fullPrompt: string;
|
||||
|
||||
@Column({ comment: '随机产生的绘画ID用于拿取比对结果' })
|
||||
randomDrawId: string;
|
||||
|
||||
@Column({ comment: '当前绘制任务的进度', nullable: true })
|
||||
progress: number;
|
||||
|
||||
@@ -35,7 +32,7 @@ export class MidjourneyEntity extends BaseEntity {
|
||||
status: number;
|
||||
|
||||
@Column({ comment: 'mj绘画的动作、绘图、放大、变换、图生图' })
|
||||
action: number;
|
||||
action: string;
|
||||
|
||||
@Column({ comment: '一组图片的第几张、放大或者变换的时候需要使用', nullable: true })
|
||||
orderId: number;
|
||||
@@ -43,14 +40,17 @@ export class MidjourneyEntity extends BaseEntity {
|
||||
@Column({ comment: '是否推荐0: 默认不推荐 1: 推荐', nullable: true, default: 0 })
|
||||
rec: number;
|
||||
|
||||
@Column({ comment: '对图片操作的', nullable: true })
|
||||
customId: string;
|
||||
|
||||
@Column({ comment: '绘画的ID每条不一样', nullable: true })
|
||||
message_id: string;
|
||||
drawId: string;
|
||||
|
||||
@Column({ comment: '对图片放大或者变体的ID', nullable: true })
|
||||
custom_id: string;
|
||||
@Column({ comment: '图片链接', nullable: true, type: 'text' })
|
||||
drawUrl: string;
|
||||
|
||||
@Column({ comment: '图片信息尺寸', nullable: true, type: 'text' })
|
||||
fileInfo: string;
|
||||
@Column({ comment: '图片比例', nullable: true, type: 'text' })
|
||||
drawRatio: string;
|
||||
|
||||
@Column({ comment: '扩展参数', nullable: true, type: 'text' })
|
||||
extend: string;
|
||||
@@ -60,4 +60,5 @@ export class MidjourneyEntity extends BaseEntity {
|
||||
|
||||
@Column({ comment: '是否存入了图片到配置的储存项 配置了则存储 不配置地址则是源地址', default: true })
|
||||
isSaveImg: boolean;
|
||||
messageId: any;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,4 +1,4 @@
|
||||
import { AddBadWordDto } from '../../badwords/dto/addBadWords.dto';
|
||||
import { AddBadWordDto } from './../../badwords/dto/addBadWords.dto';
|
||||
import { IsNotEmpty, MinLength, MaxLength, IsEmail, IsOptional, IsNumber } from 'class-validator';
|
||||
import { ApiProperty } from '@nestjs/swagger';
|
||||
|
||||
@@ -27,6 +27,9 @@ export class SetModelDto {
|
||||
@ApiProperty({ example: 1, description: 'key的权重' })
|
||||
keyWeight: number;
|
||||
|
||||
@ApiProperty({ example: 1, description: '模型排序' })
|
||||
modelOrder: number;
|
||||
|
||||
@ApiProperty({ example: 4096, description: '模型支持的最大TOken数量', required: true })
|
||||
maxModelTokens: number;
|
||||
|
||||
@@ -53,9 +56,10 @@ export class SetModelDto {
|
||||
|
||||
@ApiProperty({ example: true, description: '是否设置为绘画Key', required: false })
|
||||
isDraw: boolean;
|
||||
//设置token计费
|
||||
|
||||
@ApiProperty({ example: true, description: '是否使用token计费', required: false })
|
||||
isTokenBased: boolean;
|
||||
|
||||
@ApiProperty({ example: true, description: 'token计费比例', required: false })
|
||||
tokenFeeRatio: number;
|
||||
}
|
||||
|
||||
@@ -42,4 +42,7 @@ export class ModelsTypeEntity extends BaseEntity {
|
||||
|
||||
@Column({ comment: '是否为特殊模型、可以提供联想翻译、思维导图等特殊操作', default: 0 })
|
||||
isUseTool: boolean;
|
||||
|
||||
@Column({ comment: '模型排序', default: 1 })
|
||||
modelOrder: number;
|
||||
}
|
||||
|
||||
@@ -69,4 +69,8 @@ export class ModelsEntity extends BaseEntity {
|
||||
|
||||
@Column({ comment: 'token计费比例', default: 0 })
|
||||
tokenFeeRatio: number;
|
||||
|
||||
@Column({ comment: 'key权重', default: 1 })
|
||||
modelOrder: number;
|
||||
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import { ModelsEntity } from './models.entity';
|
||||
import { SetModelDto } from './dto/setModel.dto';
|
||||
import { QueryModelDto } from './dto/queryModel.dto';
|
||||
import { ModelsMapCn } from '@/common/constants/status.constant';
|
||||
import { getAccessToken } from '../chatgpt/baidu';
|
||||
// import { getAccessToken } from '../chatgpt/baidu';
|
||||
import { getRandomItemFromArray, hideString } from '@/common/utils';
|
||||
import { ModelsTypeEntity } from './modelType.entity';
|
||||
import { SetModelTypeDto } from './dto/setModelType.dto';
|
||||
@@ -19,8 +19,8 @@ export class ModelsService {
|
||||
private readonly modelsEntity: Repository<ModelsEntity>,
|
||||
@InjectRepository(ModelsTypeEntity)
|
||||
private readonly modelsTypeEntity: Repository<ModelsTypeEntity>,
|
||||
){}
|
||||
|
||||
) { }
|
||||
|
||||
private modelTypes = []
|
||||
private modelMaps = {}
|
||||
private keyList = {}
|
||||
@@ -28,145 +28,145 @@ export class ModelsService {
|
||||
private keyPoolMap = {} // 记录每个模型的所有key 并且记录顺序
|
||||
private keyPoolIndexMap = {} // 记录每个模型的当前调用的下标
|
||||
|
||||
async onModuleInit(){
|
||||
async onModuleInit() {
|
||||
await this.initCalcKey()
|
||||
this.refreshBaiduAccesstoken()
|
||||
}
|
||||
|
||||
/* 初始化整理所有key 进行分类并且默认一个初始模型配置 默认是配置的第一个分类的第一个key为准 */
|
||||
async initCalcKey(){
|
||||
async initCalcKey() {
|
||||
this.keyPoolMap = {}
|
||||
this.keyPoolIndexMap = {}
|
||||
this.keyList = {}
|
||||
this.modelMaps = {}
|
||||
this.modelTypes = []
|
||||
const allKeys = await this.modelsEntity.find({where: { status: true }})
|
||||
const keyTypes = allKeys.reduce( (pre: any, cur ) => {
|
||||
if(!pre[cur.keyType]){
|
||||
const allKeys = await this.modelsEntity.find({ where: { status: true } })
|
||||
const keyTypes = allKeys.reduce((pre: any, cur) => {
|
||||
if (!pre[cur.keyType]) {
|
||||
pre[cur.keyType] = [cur]
|
||||
}else{
|
||||
} else {
|
||||
pre[cur.keyType].push(cur)
|
||||
}
|
||||
return pre
|
||||
}, {})
|
||||
this.modelTypes = Object.keys(keyTypes).map( keyType => {
|
||||
return { label: ModelsMapCn[keyType] , val: keyType}
|
||||
this.modelTypes = Object.keys(keyTypes).map(keyType => {
|
||||
return { label: ModelsMapCn[keyType], val: keyType }
|
||||
})
|
||||
this.modelMaps = keyTypes
|
||||
this.keyList = {}
|
||||
allKeys.forEach( keyDetail => {
|
||||
|
||||
allKeys.forEach(keyDetail => {
|
||||
const { keyType, model, keyWeight } = keyDetail
|
||||
if(!this.keyPoolMap[model]) this.keyPoolMap[model] = []
|
||||
if (!this.keyPoolMap[model]) this.keyPoolMap[model] = []
|
||||
for (let index = 0; index < keyWeight; index++) {
|
||||
this.keyPoolMap[model].push(keyDetail)
|
||||
}
|
||||
if(!this.keyPoolIndexMap[model]) this.keyPoolIndexMap[model] = 0
|
||||
if(!this.keyList[keyType]) this.keyList[keyType] = {}
|
||||
if(!this.keyList[keyType][model]) this.keyList[keyType][model] = []
|
||||
if (!this.keyPoolIndexMap[model]) this.keyPoolIndexMap[model] = 0
|
||||
if (!this.keyList[keyType]) this.keyList[keyType] = {}
|
||||
if (!this.keyList[keyType][model]) this.keyList[keyType][model] = []
|
||||
this.keyList[keyType][model].push(keyDetail)
|
||||
})
|
||||
}
|
||||
|
||||
/* lock key 自动锁定key */
|
||||
async lockKey(keyId, remark, keyStatus = -1){
|
||||
async lockKey(keyId, remark, keyStatus = -1) {
|
||||
const res = await this.modelsEntity.update({ id: keyId }, { status: false, keyStatus, remark });
|
||||
Logger.error(`key: ${keyId} 欠费或被官方封禁导致不可用,已被系统自动锁定`);
|
||||
this.initCalcKey()
|
||||
}
|
||||
|
||||
/* 获取本次调用key的详细信息 */
|
||||
async getCurrentModelKeyInfo(model){
|
||||
if(!this.keyPoolMap[model]){
|
||||
async getCurrentModelKeyInfo(model) {
|
||||
if (!this.keyPoolMap[model]) {
|
||||
throw new HttpException('当前调用模型已经被移除、请重新选择模型!', HttpStatus.BAD_REQUEST)
|
||||
}
|
||||
/* 调用下标+1 */
|
||||
this.keyPoolIndexMap[model]++
|
||||
/* 判断下标超出边界没有 */
|
||||
const index = this.keyPoolIndexMap[model]
|
||||
if(index >= this.keyPoolMap[model].length) this.keyPoolIndexMap[model] = 0
|
||||
const key = this.keyPoolMap[model][this.keyPoolIndexMap[model]]
|
||||
if (index >= this.keyPoolMap[model].length) this.keyPoolIndexMap[model] = 0
|
||||
const key = this.keyPoolMap[model][this.keyPoolIndexMap[model]]
|
||||
return key
|
||||
}
|
||||
|
||||
/* 通过现有配置的key和分类给到默认的配置信息 默认给到第一个分类的第一个key的配置 */
|
||||
async getBaseConfig(appId?: number): Promise<any>{
|
||||
if(!this.modelTypes.length || !Object.keys(this.modelMaps).length) return;
|
||||
async getBaseConfig(appId?: number): Promise<any> {
|
||||
if (!this.modelTypes.length || !Object.keys(this.modelMaps).length) return;
|
||||
/* 有appid只可以使用openai 的 模型 */
|
||||
const modelTypeInfo = appId ? this.modelTypes.find( item => Number(item.val) === 1) : this.modelTypes[0]
|
||||
const modelTypeInfo = appId ? this.modelTypes.find(item => Number(item.val) === 1) : this.modelTypes[0]
|
||||
// TODO 第0个会有问题 先添加的4默认就是模型4了 后面优化下
|
||||
if(!modelTypeInfo) return;
|
||||
if (!modelTypeInfo) return;
|
||||
const { keyType, modelName, model, maxModelTokens, maxResponseTokens, deductType, deduct, maxRounds } = this.modelMaps[modelTypeInfo.val][0] // 取到第一个默认的配置项信息
|
||||
return {
|
||||
modelTypeInfo,
|
||||
modelInfo: { keyType, modelName, model, maxModelTokens, maxResponseTokens, topN: 0.8, systemMessage: '', deductType, deduct, maxRounds, rounds: 8 }
|
||||
modelInfo: { keyType, modelName, model, maxModelTokens, maxResponseTokens, topN: 0.8, systemMessage: '', deductType, deduct, maxRounds, rounds: 8 }
|
||||
}
|
||||
}
|
||||
|
||||
async setModel(params: SetModelDto){
|
||||
try {
|
||||
const { id } = params
|
||||
params.status && (params.keyStatus = 1)
|
||||
if(id){
|
||||
const res = await this.modelsEntity.update({id}, params)
|
||||
await this.initCalcKey()
|
||||
return res.affected > 0
|
||||
}else{
|
||||
const { keyType, key } = params
|
||||
if(Number(keyType !== 1)){
|
||||
const res = await this.modelsEntity.save(params)
|
||||
async setModel(params: SetModelDto) {
|
||||
try {
|
||||
const { id } = params
|
||||
params.status && (params.keyStatus = 1)
|
||||
if (id) {
|
||||
const res = await this.modelsEntity.update({ id }, params)
|
||||
await this.initCalcKey()
|
||||
if(keyType === 2){ //百度的需要刷新token
|
||||
this.refreshBaiduAccesstoken()
|
||||
return res.affected > 0
|
||||
} else {
|
||||
const { keyType, key } = params
|
||||
if (Number(keyType !== 1)) {
|
||||
const res = await this.modelsEntity.save(params)
|
||||
await this.initCalcKey()
|
||||
return res
|
||||
} else {
|
||||
const data = key.map(k => {
|
||||
try {
|
||||
const data = JSON.parse(JSON.stringify(params))
|
||||
data.key = k
|
||||
return data
|
||||
} catch (error) {
|
||||
console.log('parse error: ', error);
|
||||
}
|
||||
})
|
||||
const res = await this.modelsEntity.save(data)
|
||||
await this.initCalcKey()
|
||||
return res
|
||||
}
|
||||
return res
|
||||
}else{
|
||||
const data = key.map( k => {
|
||||
try {
|
||||
const data = JSON.parse(JSON.stringify(params))
|
||||
data.key = k
|
||||
return data
|
||||
} catch (error) {
|
||||
console.log('parse error: ', error);
|
||||
}
|
||||
})
|
||||
const res = await this.modelsEntity.save(data)
|
||||
await this.initCalcKey()
|
||||
return res
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('error: ', error);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('error: ', error);
|
||||
}
|
||||
}
|
||||
|
||||
async delModel({id}){
|
||||
if(!id) {
|
||||
async delModel({ id }) {
|
||||
if (!id) {
|
||||
throw new HttpException('缺失必要参数!', HttpStatus.BAD_REQUEST)
|
||||
}
|
||||
const m = await this.modelsEntity.findOne({where: {id}})
|
||||
if(!m){
|
||||
const m = await this.modelsEntity.findOne({ where: { id } })
|
||||
if (!m) {
|
||||
throw new HttpException('当前账号不存在!', HttpStatus.BAD_REQUEST)
|
||||
}
|
||||
const res = await this.modelsEntity.delete({id})
|
||||
const res = await this.modelsEntity.delete({ id })
|
||||
await this.initCalcKey()
|
||||
return res;
|
||||
}
|
||||
|
||||
async queryModels(req, params: QueryModelDto){
|
||||
async queryModels(req, params: QueryModelDto) {
|
||||
const { role } = req.user
|
||||
const { keyType, key, status, model, page = 1, size = 10 } = params
|
||||
let where: any = {}
|
||||
keyType && (where.keyType = keyType)
|
||||
model && (where.model = model)
|
||||
status && (where.status = Number(status) === 1 ? true : false)
|
||||
key && ( where.key = Like(`%${key}%`))
|
||||
key && (where.key = Like(`%${key}%`))
|
||||
const [rows, count] = await this.modelsEntity.findAndCount({
|
||||
where: where,
|
||||
skip: (page - 1) * size,
|
||||
take: size,
|
||||
where: where,
|
||||
order: {
|
||||
modelOrder: 'ASC'
|
||||
},
|
||||
skip: (page - 1) * size,
|
||||
take: size,
|
||||
})
|
||||
if(role !== 'super'){
|
||||
rows.forEach( item => {
|
||||
if (role !== 'super') {
|
||||
rows.forEach(item => {
|
||||
item.key && (item.key = hideString(item.key))
|
||||
item.secret && (item.secret = hideString(item.secret))
|
||||
})
|
||||
@@ -176,24 +176,29 @@ export class ModelsService {
|
||||
}
|
||||
|
||||
/* 客户端查询到的所有的配置的模型类别 以及类别下自定义的多少中文模型名称 */
|
||||
async modelsList(){
|
||||
const cloneModelMaps = JSON.parse(JSON.stringify(this.modelMaps))
|
||||
Object.keys(cloneModelMaps).forEach( key => {
|
||||
async modelsList() {
|
||||
const cloneModelMaps = JSON.parse(JSON.stringify(this.modelMaps));
|
||||
Object.keys(cloneModelMaps).forEach(key => {
|
||||
// 对每个模型进行排序
|
||||
cloneModelMaps[key] = cloneModelMaps[key].sort((a, b) => a.modelOrder - b.modelOrder);
|
||||
cloneModelMaps[key] = Array.from(
|
||||
cloneModelMaps[key].map( t => {
|
||||
const { modelName, model, deduct, deductType, maxRounds } = t
|
||||
return { modelName, model, deduct, deductType, maxRounds }
|
||||
}).reduce((map, obj) => map.set(obj.modelName, obj), new Map()).values()
|
||||
cloneModelMaps[key]
|
||||
.map(t => {
|
||||
const { modelName, model, deduct, deductType, maxRounds } = t;
|
||||
return { modelName, model, deduct, deductType, maxRounds };
|
||||
})
|
||||
.reduce((map, obj) => map.set(obj.modelName, obj), new Map()).values()
|
||||
);
|
||||
})
|
||||
});
|
||||
|
||||
return {
|
||||
modelTypeList: this.modelTypes,
|
||||
modelMaps: cloneModelMaps
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/* 记录使用次数和使用的token数量 */
|
||||
async saveUseLog(id, useToken){
|
||||
async saveUseLog(id, useToken) {
|
||||
await this.modelsEntity
|
||||
.createQueryBuilder()
|
||||
.update(ModelsEntity)
|
||||
@@ -202,54 +207,32 @@ export class ModelsService {
|
||||
.execute();
|
||||
}
|
||||
|
||||
async refreshBaiduAccesstoken(){
|
||||
const allKeys = await this.modelsEntity.find({ where: { keyType: 2 } })
|
||||
const keysMap: any = {}
|
||||
allKeys.forEach( keyInfo => {
|
||||
const { key, secret } = keyInfo
|
||||
if(!keysMap.key){
|
||||
keysMap[key] = [{ keyInfo }]
|
||||
}else{
|
||||
keysMap[key].push(keyInfo)
|
||||
}
|
||||
})
|
||||
Object.keys(keysMap).forEach( async key => {
|
||||
const {secret, id } = keysMap[key][0]['keyInfo']
|
||||
const accessToken: any = await getAccessToken(key, secret)
|
||||
await this.modelsEntity.update({ key }, { accessToken })
|
||||
})
|
||||
|
||||
setTimeout(() => {
|
||||
this.initCalcKey()
|
||||
}, 1000)
|
||||
}
|
||||
|
||||
/* 获取一张绘画key */
|
||||
async getRandomDrawKey(){
|
||||
const drawkeys = await this.modelsEntity.find({where: { isDraw: true, status: true }})
|
||||
if(!drawkeys.length){
|
||||
async getRandomDrawKey() {
|
||||
const drawkeys = await this.modelsEntity.find({ where: { isDraw: true, status: true } })
|
||||
if (!drawkeys.length) {
|
||||
throw new HttpException('当前未指定特殊模型KEY、前往后台模型池设置吧!', HttpStatus.BAD_REQUEST)
|
||||
}
|
||||
return getRandomItemFromArray(drawkeys)
|
||||
}
|
||||
|
||||
/* 获取所有key */
|
||||
async getAllKey(){
|
||||
async getAllKey() {
|
||||
return await this.modelsEntity.find()
|
||||
}
|
||||
|
||||
/* 查询模型类型 */
|
||||
async queryModelType(params: QueryModelTypeDto){
|
||||
async queryModelType(params: QueryModelTypeDto) {
|
||||
return 1
|
||||
}
|
||||
|
||||
/* 创建修改模型类型 */
|
||||
async setModelType(params: SetModelTypeDto){
|
||||
async setModelType(params: SetModelTypeDto) {
|
||||
return 1
|
||||
}
|
||||
|
||||
/* 删除模型类型 */
|
||||
async delModelType(params){
|
||||
async delModelType(params) {
|
||||
return 1
|
||||
}
|
||||
|
||||
|
||||
@@ -20,9 +20,9 @@ export class MjDrawDto {
|
||||
@IsOptional()
|
||||
imgUrl?: string;
|
||||
|
||||
@ApiProperty({ example: 1, description: '绘画动作 绘图、放大、变换、图生图' })
|
||||
@ApiProperty({ example: 'IMAGINE', description: '任务类型,可用值:IMAGINE,UPSCALE,VARIATION,ZOOM,PAN,DESCRIBE,BLEND,SHORTEN,SWAP_FACE' })
|
||||
@IsOptional()
|
||||
action: number;
|
||||
action: string;
|
||||
|
||||
@ApiProperty({ example: 1, description: '变体或者放大的序号' })
|
||||
@IsOptional()
|
||||
@@ -31,4 +31,8 @@ export class MjDrawDto {
|
||||
@ApiProperty({ example: 1, description: '绘画的DBID' })
|
||||
@IsOptional()
|
||||
drawId: number;
|
||||
|
||||
@ApiProperty({ example: 1, description: '任务ID' })
|
||||
@IsOptional()
|
||||
taskId: number;
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
private readonly midjourneyService: MidjourneyService,
|
||||
private readonly userBalanceService: UserBalanceService,
|
||||
private readonly globalConfigService: GlobalConfigService,
|
||||
) {}
|
||||
) { }
|
||||
private readonly jobIds: any[] = [];
|
||||
|
||||
async onApplicationBootstrap() {
|
||||
@@ -27,14 +27,13 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
|
||||
/* 提交绘画任务 */
|
||||
async addMjDrawQueue(body: MjDrawDto, req: Request) {
|
||||
const { prompt, imgUrl, extraParam, orderId, action = 1, drawId } = body;
|
||||
const { imgUrl, orderId, action, drawId } = body;
|
||||
/* 限制普通用户队列最多可以有两个任务在排队或者等待中 */
|
||||
await this.midjourneyService.checkLimit(req);
|
||||
/* 检测余额 */
|
||||
await this.userBalanceService.validateBalance(req, 'mjDraw', action === 2 ? 1 : 4);
|
||||
|
||||
await this.userBalanceService.validateBalance(req, 'mjDraw', action === 'UPSCALE' ? 1 : 4);
|
||||
/* 绘图或者图生图 */
|
||||
if (action === MidjourneyActionEnum.DRAW || action === MidjourneyActionEnum.GENERATE) {
|
||||
if (action === 'IMAGINE') {
|
||||
/* 绘图或者图生图是相同的 区分一个action即可 */
|
||||
const randomDrawId = `${createRandomUid()}`;
|
||||
const params = { ...body, userId: req.user.id, randomDrawId };
|
||||
@@ -44,84 +43,28 @@ export class QueueService implements OnApplicationBootstrap {
|
||||
/* 添加任务到队列 通过imgUrl判断是不是图生图 */
|
||||
const job = await this.mjDrawQueue.add(
|
||||
'mjDraw',
|
||||
{ id: res.id, action: imgUrl ? 4 : 1, userId: req.user.id },
|
||||
{ id: res.id, action: action, userId: req.user.id },
|
||||
{ delay: 1000, timeout: +timeout },
|
||||
);
|
||||
/* 绘图和图生图扣除余额4 */
|
||||
this.jobIds.push(job.id);
|
||||
/* 扣费 */
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return true;
|
||||
} else {
|
||||
const { orderId, action, drawId } = body;
|
||||
const actionDetail = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
|
||||
const params = { ...body, userId: req.user.id, ...actionDetail };
|
||||
const res = await this.midjourneyService.addDrawQueue(params);
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!drawId || !orderId) {
|
||||
throw new HttpException('缺少必要参数!', HttpStatus.BAD_REQUEST);
|
||||
}
|
||||
/* 图片操作 */
|
||||
|
||||
/* 图片放大 */
|
||||
if (action === MidjourneyActionEnum.UPSCALE) {
|
||||
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
|
||||
const { custom_id } = actionDetail;
|
||||
/* 检测当前图片是不是已经放大过了 */
|
||||
await this.midjourneyService.checkIsUpscale(custom_id);
|
||||
const params = { ...body, userId: req.user.id, ...actionDetail };
|
||||
const res = await this.midjourneyService.addDrawQueue(params);
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
/* 扣费 */
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 1, 1);
|
||||
this.jobIds.push(job.id);
|
||||
return;
|
||||
}
|
||||
|
||||
/* 图片变体 */
|
||||
if (action === MidjourneyActionEnum.VARIATION) {
|
||||
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
|
||||
const params = { ...body, userId: req.user.id, ...actionDetail };
|
||||
const res = await this.midjourneyService.addDrawQueue(params);
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
/* 扣费 */
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
|
||||
/* 重新生成 */
|
||||
if (action === MidjourneyActionEnum.REGENERATE) {
|
||||
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
|
||||
const params = { ...body, userId: req.user.id, ...actionDetail };
|
||||
const res = await this.midjourneyService.addDrawQueue(params);
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
|
||||
/* 对图片增强 Vary */
|
||||
if (action === MidjourneyActionEnum.VARY) {
|
||||
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
|
||||
const params = { ...body, userId: req.user.id, ...actionDetail };
|
||||
const res = await this.midjourneyService.addDrawQueue(params);
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
|
||||
/* 对图片缩放 Zoom */
|
||||
if (action === MidjourneyActionEnum.ZOOM) {
|
||||
const actionDetail: any = await this.midjourneyService.getDrawActionDetail(action, drawId, orderId);
|
||||
const params = { ...body, userId: req.user.id, ...actionDetail };
|
||||
const res = await this.midjourneyService.addDrawQueue(params);
|
||||
const timeout = (await this.globalConfigService.getConfigs(['mjTimeoutMs'])) || 200000;
|
||||
const job = await this.mjDrawQueue.add('mjDraw', { id: res.id, action, userId: req.user.id }, { delay: 1000, timeout: +timeout });
|
||||
this.jobIds.push(job.id);
|
||||
// await this.userBalanceService.deductFromBalance(req.user.id, 'mjDraw', 4, 4);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/* 查询队列 */
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { GlobalConfigService } from '../globalConfig/globalConfig.service';
|
||||
import { GlobalConfigService } from './../globalConfig/globalConfig.service';
|
||||
import { Injectable, Logger, OnModuleInit } from '@nestjs/common';
|
||||
import { Cron, CronExpression } from '@nestjs/schedule';
|
||||
import { UserBalanceEntity } from '../userBalance/userBalance.entity';
|
||||
@@ -13,7 +13,7 @@ export class TaskService {
|
||||
private readonly userBalanceEntity: Repository<UserBalanceEntity>,
|
||||
private readonly globalConfigService: GlobalConfigService,
|
||||
private readonly modelsService: ModelsService,
|
||||
) {}
|
||||
) { }
|
||||
|
||||
/* 每小时刷新一次微信的token */
|
||||
@Cron(CronExpression.EVERY_HOUR)
|
||||
@@ -40,8 +40,8 @@ export class TaskService {
|
||||
}
|
||||
|
||||
/* 每小时检测一次授权 */
|
||||
@Cron('0 0 */5 * *')
|
||||
refreshBaiduAccesstoken() {
|
||||
this.modelsService.refreshBaiduAccesstoken();
|
||||
}
|
||||
// @Cron('0 0 */5 * *')
|
||||
// refreshBaiduAccesstoken() {
|
||||
// this.modelsService.refreshBaiduAccesstoken();
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -10,33 +10,46 @@ import * as FormData from 'form-data';
|
||||
|
||||
@Injectable()
|
||||
export class UploadService implements OnModuleInit {
|
||||
constructor(private readonly globalConfigService: GlobalConfigService) {}
|
||||
constructor(private readonly globalConfigService: GlobalConfigService) { }
|
||||
private tencentCos: any;
|
||||
|
||||
onModuleInit() {}
|
||||
onModuleInit() { }
|
||||
|
||||
async uploadFile(file) {
|
||||
const { filename: name, originalname, buffer, dir = 'ai', mimetype } = file;
|
||||
const fileTyle = mimetype ? mimetype.split('/')[1] : '';
|
||||
const filename = originalname || name
|
||||
Logger.debug(`准备上传文件: ${filename}, 类型: ${fileTyle}`, 'UploadService');
|
||||
|
||||
const {
|
||||
tencentCosStatus = 0,
|
||||
aliOssStatus = 0,
|
||||
cheveretoStatus = 0,
|
||||
} = await this.globalConfigService.getConfigs(['tencentCosStatus', 'aliOssStatus', 'cheveretoStatus']);
|
||||
|
||||
|
||||
Logger.debug(`上传配置状态 - 腾讯云: ${tencentCosStatus}, 阿里云: ${aliOssStatus}, Chevereto: ${cheveretoStatus}`, 'UploadService');
|
||||
|
||||
if (!Number(tencentCosStatus) && !Number(aliOssStatus) && !Number(cheveretoStatus)) {
|
||||
throw new HttpException('请先前往后台配置上传图片的方式', HttpStatus.BAD_REQUEST);
|
||||
}
|
||||
if (Number(tencentCosStatus)) {
|
||||
return this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle });
|
||||
}
|
||||
if (Number(aliOssStatus)) {
|
||||
return await this.uploadFileByAliOss({ filename, buffer, dir, fileTyle });
|
||||
}
|
||||
if (Number(cheveretoStatus)) {
|
||||
const { filename, buffer: fromBuffer, dir } = file;
|
||||
return await this.uploadFileByChevereto({ filename, buffer: fromBuffer.toString('base64'), dir, fileTyle });
|
||||
try {
|
||||
if (Number(tencentCosStatus)) {
|
||||
Logger.debug(`使用腾讯云COS上传`, 'UploadService');
|
||||
return await this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle });
|
||||
}
|
||||
if (Number(aliOssStatus)) {
|
||||
Logger.debug(`使用阿里云OSS上传`, 'UploadService');
|
||||
return await this.uploadFileByAliOss({ filename, buffer, dir, fileTyle });
|
||||
}
|
||||
if (Number(cheveretoStatus)) {
|
||||
Logger.debug(`使用Chevereto上传`, 'UploadService');
|
||||
const { filename, buffer: fromBuffer, dir } = file;
|
||||
return await this.uploadFileByChevereto({ filename, buffer: fromBuffer.toString('base64'), dir, fileTyle });
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`上传失败: ${error.message}`, 'UploadService');
|
||||
throw error; // 重新抛出异常,以便调用方可以处理
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,23 +135,10 @@ export class UploadService implements OnModuleInit {
|
||||
this.tencentCos = new TENCENTCOS({ SecretId, SecretKey, FileParallelLimit: 10 });
|
||||
try {
|
||||
const proxyMj = (await this.globalConfigService.getConfigs(['mjProxy'])) || 0;
|
||||
/* 开启代理 */
|
||||
if (Number(proxyMj) === 1) {
|
||||
const data = { cosType: 'tencent', url, cosParams: { Bucket, Region, SecretId, SecretKey } };
|
||||
const mjProxyUrl = (await this.globalConfigService.getConfigs(['mjProxyUrl'])) || 'http://172.247.48.137:8000';
|
||||
const res = await axios.post(`${mjProxyUrl}/mj/replaceUpload`, data);
|
||||
if (!res.data) throw new HttpException('上传图片失败[ten][url]', HttpStatus.BAD_REQUEST);
|
||||
let locationUrl = res.data.replace(/^(http:\/\/|https:\/\/|\/\/|)(.*)/, 'https://$2');
|
||||
const { acceleratedDomain } = await this.getUploadConfig('tencent');
|
||||
if (acceleratedDomain) {
|
||||
locationUrl = locationUrl.replace(/^(https:\/\/[^/]+)(\/.*)$/, `https://${acceleratedDomain}$2`);
|
||||
console.log('当前已开启全球加速----------------->');
|
||||
}
|
||||
return locationUrl;
|
||||
} else {
|
||||
const buffer = await this.getBufferFromUrl(url);
|
||||
return await this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle: '' });
|
||||
}
|
||||
|
||||
const buffer = await this.getBufferFromUrl(url);
|
||||
return await this.uploadFileByTencentCos({ filename, buffer, dir, fileTyle: '' });
|
||||
|
||||
} catch (error) {
|
||||
console.log('TODO->error: ', error);
|
||||
throw new HttpException('上传图片失败[ten][url]', HttpStatus.BAD_REQUEST);