Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Arcjet security (Shield, rate limit, bot detection) #69

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.local.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ NEXT_PUBLIC_CLERK_SIGN_UP_URL=/sign-up
NEXT_PUBLIC_CLERK_AFTER_SIGN_IN_URL=/
NEXT_PUBLIC_CLERK_AFTER_SIGN_UP_URL=/

# Arcjet related environment variables
ARCJET_KEY=ajkey_****

# OpenAI related environment variables
OPENAI_API_KEY=sk-****

Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Text Model: [OpenAI](https://platform.openai.com/docs/models)
- Text streaming: [ai sdk](https://github.com/vercel-labs/ai)
- Deployment: [Fly](https://fly.io/)
- Security: [Arcjet](https://arcjet.com/)

## Overview
- 🚀 [Quickstart](#quickstart)
Expand Down Expand Up @@ -76,6 +77,10 @@ e. **Supabase API key**
- `SUPABASE_PRIVATE_KEY` is the key starts with `ey` under Project API Keys
- Now, you should enable pgvector on Supabase and create a schema. You can do this easily by clicking on "SQL editor" on the left hand side on supabase UI and then clicking on "+New Query". Copy paste [this code snippet](https://github.com/a16z-infra/ai-getting-started/blob/main/pgvector.sql) in the SQL editor and click "Run".

f. **Arcjet key**

Visit https://app.arcjet.com to sign up for free and get your Arcjet key.

### 4. Generate embeddings

There are a few markdown files under `/blogs` directory as examples so you can do Q&A on them. To generate embeddings and store them in the vector database for future queries, you can run the following command:
Expand Down
117 changes: 117 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"generate-embeddings-supabase": "node src/scripts/indexBlogPGVector.mjs"
},
"dependencies": {
"@arcjet/next": "^1.0.0-alpha.13",
"@clerk/nextjs": "^4.21.9-snapshot.56dc3e",
"@headlessui/react": "^1.7.15",
"@pinecone-database/pinecone": "^0.1.6",
Expand Down
71 changes: 70 additions & 1 deletion src/app/api/qa-pg-vector/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,79 @@ import dotenv from "dotenv";
import { VectorDBQAChain } from "langchain/chains";
import { StreamingTextResponse, LangChainStream } from "ai";
import { CallbackManager } from "langchain/callbacks";
import { currentUser } from "@clerk/nextjs";
import arcjet, { shield, fixedWindow, detectBot } from "@arcjet/next";
import { NextResponse } from "next/server";

dotenv.config({ path: `.env.local` });

// The arcjet instance is created outside of the handler
const aj = arcjet({
key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com
rules: [
// Arcjet Shield protects against common attacks e.g. SQL injection
shield({
mode: "LIVE",
}),
// Create a fixed window rate limit. Other algorithms are supported.
fixedWindow({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
characteristics: ["userId"], // Rate limit based on the Clerk userId
window: "60s", // 60 second fixed window
max: 10, // allow a maximum of 10 requests
}),
// Blocks all automated clients
detectBot({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
block: ["AUTOMATED"],
}),
],
});

export async function POST(req: Request) {
// Get the current user from Clerk
const user = await currentUser();
if (!user) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}

// Use Arcjet to protect the route
const decision = await aj.protect(req, { userId: user.id });

if (decision.isDenied()) {
if (decision.reason.isRateLimit()) {
return NextResponse.json(
{
error: "Too Many Requests",
reason: decision.reason,
},
{
status: 429,
},
);
} else if (decision.reason.isBot()) {
return NextResponse.json(
{
error: "Bots are not allowed",
reason: decision.reason,
},
{
status: 403,
},
);
} else {
return NextResponse.json(
{
error: "Unauthorized",
reason: decision.reason,
},
{
status: 401,
},
);
}
}

const { prompt } = await req.json();

const privateKey = process.env.SUPABASE_PRIVATE_KEY;
Expand All @@ -31,7 +100,7 @@ export async function POST(req: Request) {
client,
tableName: "documents",
queryName: "match_documents",
}
},
);

const { stream, handlers } = LangChainStream();
Expand Down
71 changes: 70 additions & 1 deletion src/app/api/qa-pinecone/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,79 @@ import { OpenAI } from "langchain/llms/openai";
import { PineconeStore } from "langchain/vectorstores/pinecone";
import { StreamingTextResponse, LangChainStream } from "ai";
import { CallbackManager } from "langchain/callbacks";
import { currentUser } from "@clerk/nextjs";
import arcjet, { shield, fixedWindow, detectBot } from "@arcjet/next";
import { NextResponse } from "next/server";

dotenv.config({ path: `.env.local` });

// The arcjet instance is created outside of the handler
const aj = arcjet({
key: process.env.ARCJET_KEY!, // Get your site key from https://app.arcjet.com
rules: [
// Arcjet Shield protects against common attacks e.g. SQL injection
shield({
mode: "LIVE",
}),
// Create a fixed window rate limit. Other algorithms are supported.
fixedWindow({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
characteristics: ["userId"], // Rate limit based on the Clerk userId
window: "60s", // 60 second fixed window
max: 10, // allow a maximum of 10 requests
}),
// Blocks all automated clients
detectBot({
mode: "LIVE", // will block requests. Use "DRY_RUN" to log only
block: ["AUTOMATED"],
}),
],
});

export async function POST(request: Request) {
// Get the current user from Clerk
const user = await currentUser();
if (!user) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}

// Use Arcjet to protect the route
const decision = await aj.protect(request, { userId: user.id });

if (decision.isDenied()) {
if (decision.reason.isRateLimit()) {
return NextResponse.json(
{
error: "Too Many Requests",
reason: decision.reason,
},
{
status: 429,
},
);
} else if (decision.reason.isBot()) {
return NextResponse.json(
{
error: "Bots are not allowed",
reason: decision.reason,
},
{
status: 403,
},
);
} else {
return NextResponse.json(
{
error: "Unauthorized",
reason: decision.reason,
},
{
status: 401,
},
);
}
}

const { prompt } = await request.json();
const client = new PineconeClient();
await client.init({
Expand All @@ -20,7 +89,7 @@ export async function POST(request: Request) {

const vectorStore = await PineconeStore.fromExistingIndex(
new OpenAIEmbeddings({ openAIApiKey: process.env.OPENAI_API_KEY }),
{ pineconeIndex }
{ pineconeIndex },
);

const { stream, handlers } = LangChainStream();
Expand Down
Loading
Loading