Skip to content

Commit

Permalink
Flav/augment get server side props (#4140)
Browse files Browse the repository at this point in the history
* Add `withGetServerSidePropsRequirements` wrapper

* 👕

* ✨

* ✂️

* 👕

* Address comments from review

* Tmp

* Augment getServerSideProps with session

* ✨

* ✨

* ✂️

* 🔙

* 👕
  • Loading branch information
flvndvd authored Mar 5, 2024
1 parent 2e86a45 commit 2586ca0
Show file tree
Hide file tree
Showing 60 changed files with 530 additions and 515 deletions.
3 changes: 2 additions & 1 deletion front/lib/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import type {
NextApiRequest,
NextApiResponse,
} from "next";
import type { Session } from "next-auth";
import { getServerSession } from "next-auth/next";

import { isDevelopment } from "@app/lib/development";
Expand Down Expand Up @@ -446,7 +447,7 @@ export class Authenticator {
export async function getSession(
req: NextApiRequest | GetServerSidePropsContext["req"],
res: NextApiResponse | GetServerSidePropsContext["res"]
): Promise<any | null> {
): Promise<Session | null> {
return getServerSession(req, res, authOptions);
}

Expand Down
56 changes: 56 additions & 0 deletions front/lib/iam/provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import type { UserProviderType } from "@dust-tt/types";

interface LegacyProvider {
provider: UserProviderType;
id: number | string;
}

interface LegacyExternalUser {
name: string;
email: string;
image?: string;
username?: string;
email_verified?: boolean;
}

interface LegacySession {
provider: LegacyProvider;
user: LegacyExternalUser;
}

function isLegacyExternalUser(user: unknown): user is LegacyExternalUser {
return (
typeof user === "object" &&
user !== null &&
"email" in user &&
"name" in user
);
}

function isLegacyProvider(provider: unknown): provider is LegacyProvider {
return (
typeof provider === "object" &&
provider !== null &&
"provider" in provider &&
"id" in provider
);
}

export function isLegacySession(session: unknown): session is LegacySession {
return (
typeof session === "object" &&
session !== null &&
"provider" in session &&
isLegacyProvider(session.provider) &&
"user" in session &&
isLegacyExternalUser(session.user)
);
}

// We only expose generic types to ease phasing out.

export type Session = LegacySession;

export function isValidSession(session: unknown): session is Session {
return isLegacySession(session);
}
70 changes: 44 additions & 26 deletions front/lib/iam/session.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import type { RoleType, UserTypeWithWorkspaces } from "@dust-tt/types";
import type {
GetServerSideProps,
GetServerSidePropsContext,
GetServerSidePropsResult,
PreviewData,
} from "next";
import type { ParsedUrlQuery } from "querystring";
import { Op } from "sequelize";

import { getSession } from "@app/lib/auth";
import type { Session } from "@app/lib/iam/provider";
import { isValidSession } from "@app/lib/iam/provider";
import {
fetchUserFromSession,
maybeUpdateFromExternalUser,
Expand Down Expand Up @@ -84,31 +86,39 @@ export async function getUserFromSession(
};
}

interface WithGetServerSidePropsRequirementsOptions {
interface MakeGetServerSidePropsRequirementsWrapperOptions<
R extends boolean = true
> {
enableLogging?: boolean;
requireAuth?: boolean;
requireAuth: R;
}

const defaultWithGetServerSidePropsRequirements: WithGetServerSidePropsRequirementsOptions =
{
enableLogging: true,
requireAuth: true,
};
export type CustomGetServerSideProps<
Props extends { [key: string]: any } = { [key: string]: any },
Params extends ParsedUrlQuery = ParsedUrlQuery,
Preview extends PreviewData = PreviewData,
RequireAuth extends boolean = true
> = (
context: GetServerSidePropsContext<Params, Preview>,
session: RequireAuth extends true ? Session : null
) => Promise<GetServerSidePropsResult<Props>>;

export function withGetServerSidePropsRequirements<
T extends { [key: string]: any } = { [key: string]: any }
>(
getServerSideProps: GetServerSideProps<T>,
opts: WithGetServerSidePropsRequirementsOptions = defaultWithGetServerSidePropsRequirements
): GetServerSideProps<T> {
return async (
context: GetServerSidePropsContext<ParsedUrlQuery, PreviewData>
export function makeGetServerSidePropsRequirementsWrapper<
RequireAuth extends boolean = true
>({
enableLogging = true,
requireAuth,
}: MakeGetServerSidePropsRequirementsWrapperOptions<RequireAuth>) {
return <T extends { [key: string]: any } = { [key: string]: any }>(
getServerSideProps: CustomGetServerSideProps<T, any, any, RequireAuth>
) => {
const { enableLogging, requireAuth } = opts;

if (requireAuth) {
const session = await getSession(context.req, context.res);
if (!session) {
return async (
context: GetServerSidePropsContext<ParsedUrlQuery, PreviewData>
) => {
const session = requireAuth
? await getSession(context.req, context.res)
: null;
if (requireAuth && (!session || !isValidSession(session))) {
return {
redirect: {
permanent: false,
Expand All @@ -117,12 +127,20 @@ export function withGetServerSidePropsRequirements<
},
};
}
}

if (enableLogging) {
return withGetServerSidePropsLogging(getServerSideProps)(context);
}
const userSession = session as RequireAuth extends true ? Session : null;

return getServerSideProps(context);
if (enableLogging) {
return withGetServerSidePropsLogging(getServerSideProps)(
context,
userSession
);
}

return getServerSideProps(context, userSession);
};
};
}

export const withDefaultGetServerSidePropsRequirements =
makeGetServerSidePropsRequirementsWrapper({ requireAuth: true });
5 changes: 3 additions & 2 deletions front/lib/iam/users.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import type { UserProviderType } from "@dust-tt/types";
import type { Session } from "next-auth";

import { isGoogleSession } from "@app/lib/iam/session";
import { User } from "@app/lib/models/user";
import { guessFirstandLastNameFromFullName } from "@app/lib/user";

interface LegacyProviderInfo {
provider: "google" | "github";
providerId: string;
provider: UserProviderType;
providerId: number | string;
}

async function fetchUserWithLegacyProvider({
Expand Down
26 changes: 11 additions & 15 deletions front/logger/withlogging.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,9 @@ import type {
} from "@dust-tt/types";
import tracer from "dd-trace";
import StatsD from "hot-shots";
import type {
GetServerSideProps,
GetServerSidePropsContext,
NextApiRequest,
NextApiResponse,
PreviewData,
} from "next";
import type { ParsedUrlQuery } from "querystring";
import type { NextApiRequest, NextApiResponse } from "next";

import type { CustomGetServerSideProps } from "@app/lib/iam/session";

import logger from "./logger";

Expand Down Expand Up @@ -148,12 +143,13 @@ export function apiError<T>(
return;
}

export function withGetServerSidePropsLogging<T extends { [key: string]: any }>(
getServerSideProps: GetServerSideProps<T>
): GetServerSideProps<T> {
return async (
context: GetServerSidePropsContext<ParsedUrlQuery, PreviewData>
) => {
export function withGetServerSidePropsLogging<
T extends { [key: string]: any },
RequireAuth extends boolean = true
>(
getServerSideProps: CustomGetServerSideProps<T, any, any, RequireAuth>
): CustomGetServerSideProps<T, any, any, RequireAuth> {
return async (context, session) => {
const now = new Date();

let route = context.resolvedUrl.split("?")[0];
Expand All @@ -165,7 +161,7 @@ export function withGetServerSidePropsLogging<T extends { [key: string]: any }>(
}

try {
const res = await getServerSideProps(context);
const res = await getServerSideProps(context, session);

const elapsed = new Date().getTime() - now.getTime();

Expand Down
44 changes: 21 additions & 23 deletions front/pages/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -52,41 +52,39 @@ import ScrollingHeader from "@app/components/home/scrollingHeader";
import { PricePlans } from "@app/components/PlansTables";
import { getSession } from "@app/lib/auth";
import { getUserFromSession } from "@app/lib/iam/session";
import { withGetServerSidePropsRequirements } from "@app/lib/iam/session";
import { makeGetServerSidePropsRequirementsWrapper } from "@app/lib/iam/session";
import { classNames } from "@app/lib/utils";

const { GA_TRACKING_ID = "" } = process.env;

export const getServerSideProps = withGetServerSidePropsRequirements<{
export const getServerSideProps = makeGetServerSidePropsRequirementsWrapper({
requireAuth: false,
})<{
gaTrackingId: string;
}>(
async (context) => {
const session = await getSession(context.req, context.res);
const user = await getUserFromSession(session);
}>(async (context) => {
// Fetch session explicitly as this page redirects logged in users to our home page.
const session = await getSession(context.req, context.res);
const user = await getUserFromSession(session);

if (user && user.workspaces.length > 0) {
let url = `/w/${user.workspaces[0].sId}`;
if (user && user.workspaces.length > 0) {
let url = `/w/${user.workspaces[0].sId}`;

if (context.query.inviteToken) {
url = `/api/login?inviteToken=${context.query.inviteToken}`;
}

return {
redirect: {
destination: url,
permanent: false,
},
};
if (context.query.inviteToken) {
url = `/api/login?inviteToken=${context.query.inviteToken}`;
}

return {
props: { gaTrackingId: GA_TRACKING_ID },
redirect: {
destination: url,
permanent: false,
},
};
},
{
requireAuth: false,
}
);

return {
props: { gaTrackingId: GA_TRACKING_ID },
};
});

export default function Home({
gaTrackingId,
Expand Down
29 changes: 13 additions & 16 deletions front/pages/login-error.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,25 @@ import { Button, Logo } from "@dust-tt/sparkle";
import type { InferGetServerSidePropsType } from "next";
import Link from "next/link";

import { withGetServerSidePropsRequirements } from "@app/lib/iam/session";
import { makeGetServerSidePropsRequirementsWrapper } from "@app/lib/iam/session";

const { URL = "", GA_TRACKING_ID = "" } = process.env;

export const getServerSideProps = withGetServerSidePropsRequirements<{
export const getServerSideProps = makeGetServerSidePropsRequirementsWrapper({
requireAuth: false,
})<{
domain?: string;
gaTrackingId: string;
baseUrl: string;
}>(
async (context) => {
return {
props: {
domain: context.query.domain as string,
baseUrl: URL,
gaTrackingId: GA_TRACKING_ID,
},
};
},
{
requireAuth: false,
}
);
}>(async (context) => {
return {
props: {
domain: context.query.domain as string,
baseUrl: URL,
gaTrackingId: GA_TRACKING_ID,
},
};
});

export default function LoginError({
domain,
Expand Down
12 changes: 6 additions & 6 deletions front/pages/no-workspace.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import type { UserTypeWithWorkspaces } from "@dust-tt/types";
import type { InferGetServerSidePropsType } from "next";
import { useRouter } from "next/router";

import { getSession } from "@app/lib/auth";
import { getUserFromSession } from "@app/lib/iam/session";
import { withGetServerSidePropsRequirements } from "@app/lib/iam/session";
import {
getUserFromSession,
withDefaultGetServerSidePropsRequirements,
} from "@app/lib/iam/session";
import { Membership, Workspace, WorkspaceHasDomain } from "@app/lib/models";
import logger from "@app/logger/logger";

Expand Down Expand Up @@ -55,13 +56,12 @@ async function fetchRevokedWorkspace(
return Workspace.findByPk(revokedWorkspaceId);
}

export const getServerSideProps = withGetServerSidePropsRequirements<{
export const getServerSideProps = withDefaultGetServerSidePropsRequirements<{
status: "auto-join-disabled" | "revoked";
userFirstName: string;
workspaceName: string;
workspaceVerifiedDomain: string | null;
}>(async (context) => {
const session = await getSession(context.req, context.res);
}>(async (context, session) => {
const user = await getUserFromSession(session);

if (!user) {
Expand Down
9 changes: 4 additions & 5 deletions front/pages/poke/[wId]/assistants/[aId]/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ import type { InferGetServerSidePropsType } from "next";

import PokeNavbar from "@app/components/poke/PokeNavbar";
import { getAgentConfigurations } from "@app/lib/api/assistant/configuration";
import { Authenticator, getSession } from "@app/lib/auth";
import { withGetServerSidePropsRequirements } from "@app/lib/iam/session";
import { Authenticator } from "@app/lib/auth";
import { withDefaultGetServerSidePropsRequirements } from "@app/lib/iam/session";

export const getServerSideProps = withGetServerSidePropsRequirements<{
export const getServerSideProps = withDefaultGetServerSidePropsRequirements<{
agentConfigurations: AgentConfigurationType[];
}>(async (context) => {
const session = await getSession(context.req, context.res);
}>(async (context, session) => {
const auth = await Authenticator.fromSuperUserSession(
session,
context.params?.wId as string
Expand Down
Loading

0 comments on commit 2586ca0

Please sign in to comment.