Merge pull request #39 from sijinhui/update_auth

进一步优化认证流程
This commit is contained in:
sijinhui 2024-03-31 01:19:33 +08:00 committed by GitHub
commit f2d6a6dfbe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 62 additions and 67 deletions

View File

@ -4,6 +4,7 @@ import { DEFAULT_MODELS, OPENAI_BASE_URL, GEMINI_BASE_URL } from "../constant";
import { collectModelTable } from "../utils/model"; import { collectModelTable } from "../utils/model";
import { makeAzurePath } from "../azure"; import { makeAzurePath } from "../azure";
import { getIP } from "@/app/api/auth"; import { getIP } from "@/app/api/auth";
import { getSessionName } from "@/lib/auth";
const serverConfig = getServerSideConfig(); const serverConfig = getServerSideConfig();
@ -149,36 +150,31 @@ export async function requestLog(
req: NextRequest, req: NextRequest,
jsonBody: any, jsonBody: any,
url_path: string, url_path: string,
name?: string,
) { ) {
// LOG // LOG
try { try {
if (url_path.startsWith("mj/") && !url_path.startsWith("mj/submit/")) { if (url_path.startsWith("mj/") && !url_path.startsWith("mj/submit/")) {
return; return;
} }
// const protocol = req.headers.get("x-forwarded-proto") || "http";
//const baseUrl = process.env.NEXTAUTH_URL ?? "http://localhost:3000";
const baseUrl = "http://localhost:3000"; const baseUrl = "http://localhost:3000";
const ip = getIP(req); const ip = getIP(req);
// 对其进行 Base64 解码
let h_userName = req.headers.get("x-request-name"); let { session, name } = await getSessionName();
if (h_userName) { console.log("[中文]", name, session, baseUrl);
const buffer = Buffer.from(h_userName, "base64");
h_userName = decodeURIComponent(buffer.toString("utf-8"));
}
console.log("[中文]", h_userName, baseUrl);
const logData = { const logData = {
ip: ip, ip: ip,
path: url_path, path: url_path,
logEntry: JSON.stringify(jsonBody), logEntry: JSON.stringify(jsonBody),
model: url_path.startsWith("mj/") ? "midjourney" : jsonBody?.model, // 后面尝试请求是添加到参数 model: url_path.startsWith("mj/") ? "midjourney" : jsonBody?.model, // 后面尝试请求是添加到参数
userName: h_userName, userName: name,
userID: session?.user?.id,
}; };
await fetch(`${baseUrl}/api/logs/openai`, { await fetch(`${baseUrl}/api/logs/openai`, {
method: "POST", method: "POST",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
// ...req.headers,
}, },
body: JSON.stringify(logData), body: JSON.stringify(logData),
}); });

View File

@ -1,29 +1,21 @@
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
import prisma from "@/lib/prisma"; import prisma from "@/lib/prisma";
import { insertUser } from "@/lib/auth"; import { insertUser } from "@/lib/auth";
// import { getTokenLength } from "@/app/utils/token";
// import { Tiktoken } from "tiktoken/lite"
// import cl100k_base from "tiktoken/encoders/cl100k_base.json"
// import "tiktoken";
// import { get_encoding } from "tiktoken";
import { addHours, subMinutes } from "date-fns";
import { getTokenLength } from "@/lib/utils"; import { getTokenLength } from "@/lib/utils";
// function getTokenLength(input: string): number {
// const encoding = get_encoding("cl100k_base");
// // console.log('tokens: ', input, encoding.countTokens())
// return encoding.encode(input).length;
// }
async function handle( async function handle(
req: NextRequest, req: NextRequest,
{ params }: { params: { path: string[] } }, { params }: { params: { path: string[] } },
) { ) {
try { try {
const request_data = await req.json(); const request_data = await req.json();
if (request_data?.userName) { console.log("log", request_data);
await insertUser({ name: request_data?.userName }); // if (request_data?.userName) {
} // await insertUser({
// name: request_data?.userName,
// email: request_data?.userName,
// });
// }
// console.log("===========4", request_data); // console.log("===========4", request_data);
try { try {
if (request_data?.logEntry) { if (request_data?.logEntry) {

View File

@ -8,10 +8,8 @@ import {
} from "@/app/constant"; } from "@/app/constant";
import { prettyObject } from "@/app/utils/format"; import { prettyObject } from "@/app/utils/format";
import { NextRequest, NextResponse } from "next/server"; import { NextRequest, NextResponse } from "next/server";
import { auth, getIP } from "../../auth"; import { auth } from "../../auth";
import { getToken } from "next-auth/jwt";
import { requestLog, requestOpenai } from "../../common"; import { requestLog, requestOpenai } from "../../common";
import { headers } from "next/headers";
const ALLOWD_PATH = new Set(Object.values({ ...OpenaiPath, ...AZURE_PATH })); const ALLOWD_PATH = new Set(Object.values({ ...OpenaiPath, ...AZURE_PATH }));
@ -113,7 +111,7 @@ async function handle(
export const GET = handle; export const GET = handle;
export const POST = handle; export const POST = handle;
export const runtime = "edge"; // export const runtime = "edge";
export const preferredRegion = [ export const preferredRegion = [
"arn1", "arn1",
"bom1", "bom1",

View File

@ -2,10 +2,10 @@ import {getServerSession, type NextAuthOptions, Theme} from "next-auth";
import GitHubProvider from "next-auth/providers/github"; import GitHubProvider from "next-auth/providers/github";
import EmailProvider from "next-auth/providers/email"; import EmailProvider from "next-auth/providers/email";
import CredentialsProvider from "next-auth/providers/credentials"; import CredentialsProvider from "next-auth/providers/credentials";
import { PrismaAdapter } from "@next-auth/prisma-adapter"; import {PrismaAdapter} from "@next-auth/prisma-adapter";
import prisma from "@/lib/prisma"; import prisma from "@/lib/prisma";
import { isEmail, isName } from "@/lib/auth_list"; import {isEmail, isName} from "@/lib/auth_list";
import { createTransport } from "nodemailer"; import {createTransport} from "nodemailer";
const SECURE_COOKIES:boolean = !!process.env.SECURE_COOKIES; const SECURE_COOKIES:boolean = !!process.env.SECURE_COOKIES;
@ -82,15 +82,13 @@ export const authOptions: NextAuthOptions = {
// 判断姓名格式是否符合要求,不符合则拒绝 // 判断姓名格式是否符合要求,不符合则拒绝
if (username && isName(username)) { if (username && isName(username)) {
// Any object returned will be saved in `user` property of the JWT // Any object returned will be saved in `user` property of the JWT
let user:{[key: string]: string} = { let user:{[key: string]: string} = {}
name: username,
// email: null
}
if (isEmail(username)) { if (isEmail(username)) {
user['email'] = username; user['email'] = username;
} else {
user['name'] = username;
} }
await insertUser(user); return await insertUser(user) ?? user
return user
} else { } else {
// If you return null then an error will be displayed advising the user to check their details. // If you return null then an error will be displayed advising the user to check their details.
// return null // return null
@ -125,7 +123,7 @@ export const authOptions: NextAuthOptions = {
callbacks: { callbacks: {
jwt: async ({ token, user }) => { jwt: async ({ token, user }) => {
// const current_time = Math.floor(Date.now() / 1000); // const current_time = Math.floor(Date.now() / 1000);
// console.log('=============', token, user, current_time) console.log('=============', token, user,)
if (user) { if (user) {
token.user = user; token.user = user;
} }
@ -139,6 +137,7 @@ export const authOptions: NextAuthOptions = {
// @ts-expect-error // @ts-expect-error
username: token?.user?.username || token?.user?.gh_username, username: token?.user?.username || token?.user?.gh_username,
}; };
console.log('555555555,', session, token)
return session; return session;
}, },
}, },
@ -157,6 +156,15 @@ export function getSession() {
} | null>; } | null>;
} }
export async function getSessionName() {
const session = await getSession();
console.log('in........', session)
return {
name: session?.user?.email || session?.user?.name,
session
}
}
// export function withSiteAuth(action: any) { // export function withSiteAuth(action: any) {
// return async ( // return async (
// formData: FormData | null, // formData: FormData | null,
@ -226,21 +234,22 @@ export async function insertUser(user: {[key: string]: string}) {
} }
const existingUser = conditions.length? await prisma.user.findFirst({ const existingUser = conditions.length? await prisma.user.findFirst({
where: { where: {
OR: conditions, AND: conditions,
}, },
}) : null; }) : null;
// console.log('[LOG]', existingUser, user, '=======') // console.log('[LOG]', existingUser, user, '=======')
if (!existingUser) { if (!existingUser) {
const newUser = await prisma.user.create({ return await prisma.user.create({
data: user data: user
}) })
// console.log('[LOG]', user, '=======') } else {
console.log('user==========', existingUser)
return existingUser;
} }
} catch (e) { } catch (e) {
console.log('[Prisma Error]', e); console.log('[Prisma Error]', e);
return false; return false;
} }
return true;
} }

View File

@ -1,8 +1,7 @@
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import type { NextRequest } from "next/server"; import type { NextRequest } from "next/server";
import { getToken } from "next-auth/jwt"; import { getToken } from "next-auth/jwt";
import { DENY_LIST, isName, ADMIN_LIST } from "@/lib/auth_list"; import { isName, ADMIN_LIST } from "@/lib/auth_list";
import {use} from "react";
export default async function middleware(req: NextRequest) { export default async function middleware(req: NextRequest) {
const url = req.nextUrl; const url = req.nextUrl;
@ -60,23 +59,23 @@ export default async function middleware(req: NextRequest) {
); );
} }
if (req.method == 'POST' && (path.startsWith("/api/openai/") || path.startsWith("/api/midjourney"))) { // if (req.method == 'POST' && (path.startsWith("/api/openai/") || path.startsWith("/api/midjourney"))) {
// 重写header添加用户名 // // 重写header添加用户名
// console.log(session,'========') // // console.log(session,'========')
const requestHeaders = new Headers(req.headers) // const requestHeaders = new Headers(req.headers)
//
// 使用 encodeURIComponent 对特殊字符进行编码 // // 使用 encodeURIComponent 对特殊字符进行编码
// 将编码的 URI 组件转换成 Base64 // // 将编码的 URI 组件转换成 Base64
const encodeName = Buffer.from(encodeURIComponent(`${session?.name}`)).toString('base64'); // const encodeName = Buffer.from(encodeURIComponent(`${session?.name}`)).toString('base64');
//
requestHeaders.set('x-request-name', encodeName) // requestHeaders.set('x-request-name', encodeName)
return NextResponse.next({ // return NextResponse.next({
request: { // request: {
// New request headers // // New request headers
headers: requestHeaders, // headers: requestHeaders,
}, // },
}) // })
} // }
return NextResponse.next() return NextResponse.next()
} }

View File

@ -16,8 +16,8 @@ model User {
id String @id @default(cuid()) id String @id @default(cuid())
name String? @unique name String? @unique
// if you are using Github OAuth, you can get rid of the username attribute (that is for Twitter OAuth) // if you are using Github OAuth, you can get rid of the username attribute (that is for Twitter OAuth)
username String? username String? @unique
gh_username String? gh_username String? @unique
email String? @unique email String? @unique
emailVerified DateTime? emailVerified DateTime?
image String? image String?
@ -70,10 +70,11 @@ model LogEntry {
path String? @db.Text path String? @db.Text
model String? @db.VarChar(25) model String? @db.VarChar(25)
userName String? @db.VarChar(50) userName String? @db.VarChar(50)
userID String?
createdAt DateTime @default(now()) createdAt DateTime @default(now())
// logEntry String? @db.Text // logEntry String? @db.Text
logToken Int? @default(0) logToken Int? @default(0)
user User? @relation(fields: [userName], references: [name], onDelete: NoAction) user User? @relation(fields: [userID], references: [id], onDelete: NoAction)
} }
model VerificationToken { model VerificationToken {