diff --git a/lib/passport/passport.serializer.ts b/lib/passport/passport.serializer.ts index 634e2d11..7d1ab997 100644 --- a/lib/passport/passport.serializer.ts +++ b/lib/passport/passport.serializer.ts @@ -1,13 +1,46 @@ +import { IncomingMessage } from 'http'; import * as passport from 'passport'; -export abstract class PassportSerializer { - abstract serializeUser(user: any, done: Function); - abstract deserializeUser(payload: any, done: Function); +export abstract class PassportSerializer< + UserType extends unknown = unknown, + PayloadType extends unknown = unknown, + RequestType extends IncomingMessage = IncomingMessage +> { + abstract serializeUser( + user: UserType, + req?: RequestType + ): Promise | PayloadType; + abstract deserializeUser( + payload: PayloadType, + req?: RequestType + ): Promise | UserType; constructor() { - passport.serializeUser((user, done) => this.serializeUser(user, done)); - passport.deserializeUser((payload, done) => - this.deserializeUser(payload, done) + passport.serializeUser( + async ( + req: RequestType, + user: UserType, + done: (err: unknown, payload?: PayloadType) => unknown + ) => { + try { + done(null, await this.serializeUser(user, req)); + } catch (err) { + done(err); + } + } + ); + passport.deserializeUser( + async ( + req: RequestType, + payload: PayloadType, + done: (err: unknown, user?: UserType) => unknown + ) => { + try { + done(null, await this.deserializeUser(payload, req)); + } catch (err) { + done(err); + } + } ); } }