Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new drizzle.custom directive #67

Merged
merged 19 commits into from
Jul 5, 2024
61 changes: 45 additions & 16 deletions packages/generator/src/lib/adapter/fields/createField.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import type { DMMF } from '@prisma/generator-helper'
import { getDirective } from '~/lib/directive'
import { type ImportValue, namedImport } from '~/lib/syntaxes/imports'
import {
type ImportValue,
defaultImportValue,
namedImport,
} from '~/lib/syntaxes/imports'
import type { MakeRequired, ModifyType, Prettify } from '~/lib/types/utils'
import { getCustomDirective } from './directives/custom'

export type DefineImport = {
module: string
Expand All @@ -27,28 +32,52 @@ export function createField(input: CreateFieldInput) {

let func = `${input.func}`

const custom = getCustomDirective(field)
if (custom?.imports) {
imports = imports.concat(
custom.imports.map((def) =>
def.name.type === 'default-import'
? defaultImportValue(def.name.name, def.module, def.type ?? false)
: namedImport(def.name.names, def.module, def.type ?? false)
)
)
}

// .type<...>()
const customType = getCustomType(field)
if (customType) {
imports = imports.concat(customType.imports)
func += customType.code
if (custom?.$type) {
func += `.$type<${custom.$type}>()`
} else {
// Legacy `drizzle.type` directive
const customType = getCustomType(field)
if (customType) {
imports = imports.concat(customType.imports)
func += customType.code
}
}

const customDefault = getCustomDefault(field)
if (customDefault) {
imports = imports.concat(customDefault.imports)
func += customDefault.code
} else if (field.hasDefaultValue) {
const _field = field as FieldWithDefault
const def = input.onDefault?.(_field) ?? onDefault(_field)
if (def) {
imports = imports.concat(def.imports ?? [])
func += def.code
let hasDefaultFn = false
if (custom?.default) {
hasDefaultFn = true
func += `.$defaultFn(${custom.default})`
} else {
// Legacy `drizzle.default` directive
const customDefault = getCustomDefault(field)
if (customDefault) {
hasDefaultFn = true
imports = imports.concat(customDefault.imports)
func += customDefault.code
} else if (field.hasDefaultValue) {
const _field = field as FieldWithDefault
const def = input.onDefault?.(_field) ?? onDefault(_field)
if (def) {
imports = imports.concat(def.imports ?? [])
func += def.code
}
}
}

if (field.isId) func += input.onPrimaryKey?.(field) ?? '.primaryKey()'
else if (field.isRequired || field.hasDefaultValue || !!customDefault)
else if (field.isRequired || field.hasDefaultValue || hasDefaultFn)
func += '.notNull()'

return {
Expand Down
71 changes: 71 additions & 0 deletions packages/generator/src/lib/adapter/fields/directives/custom.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import type { DMMF } from '@prisma/generator-helper'
import * as v from 'valibot'
import getErrorMessage from '~/lib/error-message'

const DIRECTIVE = 'drizzle.custom'

export function getCustomDirective(field: DMMF.Field) {
const directiveInput = field.documentation
if (directiveInput == null || !directiveInput.startsWith(DIRECTIVE)) {
return
}

const parsing = v.safeParse(DirectiveSchema, parseJson(directiveInput))
if (!parsing.success)
throw new InvalidDirectiveShapeError({
input: directiveInput,
issues: parsing.issues,
})
return parsing.output
}

const NamedImportSchema = v.pipe(
v.array(v.string()),
v.transform((names) => ({
type: 'named-import' as const,
names,
}))
)

const DefaultImportSchema = v.pipe(
v.string(),
v.transform((name) => ({
type: 'default-import' as const,
name,
}))
)

const ImportSchema = v.object({
name: v.union([NamedImportSchema, DefaultImportSchema]),
/** e.g. "drizzle-orm" or "../my-type" */
module: v.string(),
/** Marks the import as a type import */
type: v.optional(v.boolean()),
})

const DirectiveSchema = v.object({
imports: v.optional(v.array(ImportSchema)),
$type: v.optional(v.string()),
default: v.optional(v.string()),
})

class InvalidDirectiveShapeError extends Error {
constructor(args: {
input: string
issues: [v.BaseIssue<unknown>, ...v.BaseIssue<unknown>[]]
}) {
super(
`Invalid ${DIRECTIVE} definition:\n\n— Error:${JSON.stringify(v.flatten(args.issues), null, 2)}\n—\n\n— Your Input\n${args.input}\n—`
)
}
}

function parseJson(directiveInput: string) {
try {
return JSON.parse(directiveInput.replace(DIRECTIVE, ''))
} catch (err) {
throw new Error(
`Invalid ${DIRECTIVE} JSON shape\n\n— Error:\n${getErrorMessage(err)}\n—\n\n— Your Input\n${directiveInput}\n—`
)
}
}
31 changes: 31 additions & 0 deletions packages/generator/src/lib/error-message.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Based on https://kentcdodds.com/blog/get-a-catch-block-error-message-with-typescript
*/
export default function getErrorMessage(error: unknown) {
return toErrorWithMessage(error).message
}

type ErrorWithMessage = {
message: string
}

function isErrorWithMessage(error: unknown): error is ErrorWithMessage {
return (
typeof error === 'object' &&
error !== null &&
'message' in error &&
typeof (error as Record<string, unknown>).message === 'string'
)
}

function toErrorWithMessage(maybeError: unknown): ErrorWithMessage {
if (isErrorWithMessage(maybeError)) return maybeError

try {
return new Error(JSON.stringify(maybeError))
} catch {
// fallback in case there's an error stringifying the maybeError
// like with circular references for example.
return new Error(String(maybeError))
}
}
10 changes: 10 additions & 0 deletions packages/usage/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,13 @@ model Disambiguating_Currency {
}

// #endregion

model FieldCustomization {
id Int @id @default(autoincrement())
/// drizzle.custom {
/// "imports": [{ "name": ["SomeBigInt"], "module": "~/tests/shared/testFieldCustomization", "type": true }],
/// "$type": "SomeBigInt",
/// "default": "() => 1n"
/// }
allFields BigInt
}
2 changes: 2 additions & 0 deletions packages/usage/tests/mysql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { db } from 'src/lib/mysql'
import type { Db, Schema } from 'src/lib/types'
import { testDefault } from './shared/testDefault'
import { testDisambiguatingRelationship } from './shared/testDisambiguatingRelationship'
import { testFieldCustomization } from './shared/testFieldCustomization'
import { testFields } from './shared/testFields'
import { testIgnoreDecorator } from './shared/testIgnoreDecorator'
import { testManyToMany } from './shared/testManyToMany'
Expand All @@ -25,3 +26,4 @@ testDisambiguatingRelationship(ctx)
testSelfReferring(ctx)
testIgnoreDecorator(ctx)
testDefault(ctx)
testFieldCustomization(ctx)
2 changes: 2 additions & 0 deletions packages/usage/tests/postgres.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { schema } from 'prisma/drizzle/schema'
import { db } from 'src/lib/postgres'
import { testDefault } from './shared/testDefault'
import { testDisambiguatingRelationship } from './shared/testDisambiguatingRelationship'
import { testFieldCustomization } from './shared/testFieldCustomization'
import { testFields } from './shared/testFields'
import { testIgnoreDecorator } from './shared/testIgnoreDecorator'
import { testManyToMany } from './shared/testManyToMany'
Expand All @@ -20,3 +21,4 @@ testDisambiguatingRelationship(ctx)
testSelfReferring(ctx)
testIgnoreDecorator(ctx)
testDefault(ctx)
testFieldCustomization(ctx)
29 changes: 29 additions & 0 deletions packages/usage/tests/shared/testFieldCustomization.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import fs, { readFileSync } from 'node:fs'
import type { TestContext } from 'tests/utils/types'
import { beforeAll, describe, expect, test } from 'vitest'

export type SomeBigInt = bigint

export const OUTPUT_FILE = './prisma/drizzle/field-customizations.ts'

export function testFieldCustomization({ db, schema, provider }: TestContext) {
let content: string
beforeAll(async () => {
expect(fs.existsSync(OUTPUT_FILE)).toBe(true)
content = readFileSync(OUTPUT_FILE, 'utf-8')
})

describe('allFields', () => {
test('should contain import', () => {
expect(content).include(
"import type { SomeBigInt } from '~/tests/shared/testFieldCustomization'"
)
})

test('should contain correct field definition', () => {
expect(content).include(
"allFields: bigint('allFields', { mode: 'bigint' }).$type<SomeBigInt>().$defaultFn(() => 1n)"
)
})
})
}
2 changes: 2 additions & 0 deletions packages/usage/tests/sqlite.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { db } from 'src/lib/sqlite'
import type { Db, Schema } from 'src/lib/types'
import { testDefault } from './shared/testDefault'
import { testDisambiguatingRelationship } from './shared/testDisambiguatingRelationship'
import { testFieldCustomization } from './shared/testFieldCustomization'
import { testFields } from './shared/testFields'
import { testIgnoreDecorator } from './shared/testIgnoreDecorator'
import { testManyToMany } from './shared/testManyToMany'
Expand All @@ -25,3 +26,4 @@ testDisambiguatingRelationship(ctx)
testSelfReferring(ctx)
testIgnoreDecorator(ctx)
testDefault(ctx)
testFieldCustomization(ctx)
3 changes: 3 additions & 0 deletions packages/usage/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"esModuleInterop": true,
"skipLibCheck": true,
"baseUrl": ".",
"paths": {
"~/tests/*": ["./tests/*"]
},
"moduleResolution": "Node",
"types": ["vitest/globals"]
}
Expand Down