Skip to content

Commit

Permalink
Merge pull request #67 from fdarian/drizzle-custom
Browse files Browse the repository at this point in the history
feat: new `drizzle.custom` directive
  • Loading branch information
fdarian authored Jul 5, 2024
2 parents 2bacafc + c53b52c commit 1458235
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 16 deletions.
61 changes: 45 additions & 16 deletions packages/generator/src/lib/adapter/fields/createField.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import type { DMMF } from '@prisma/generator-helper'
import { getDirective } from '~/lib/directive'
import { type ImportValue, namedImport } from '~/lib/syntaxes/imports'
import {
type ImportValue,
defaultImportValue,
namedImport,
} from '~/lib/syntaxes/imports'
import type { MakeRequired, ModifyType, Prettify } from '~/lib/types/utils'
import { getCustomDirective } from './directives/custom'

export type DefineImport = {
module: string
Expand All @@ -27,28 +32,52 @@ export function createField(input: CreateFieldInput) {

let func = `${input.func}`

const custom = getCustomDirective(field)
if (custom?.imports) {
imports = imports.concat(
custom.imports.map((def) =>
def.name.type === 'default-import'
? defaultImportValue(def.name.name, def.module, def.type ?? false)
: namedImport(def.name.names, def.module, def.type ?? false)
)
)
}

// .type<...>()
const customType = getCustomType(field)
if (customType) {
imports = imports.concat(customType.imports)
func += customType.code
if (custom?.$type) {
func += `.$type<${custom.$type}>()`
} else {
// Legacy `drizzle.type` directive
const customType = getCustomType(field)
if (customType) {
imports = imports.concat(customType.imports)
func += customType.code
}
}

const customDefault = getCustomDefault(field)
if (customDefault) {
imports = imports.concat(customDefault.imports)
func += customDefault.code
} else if (field.hasDefaultValue) {
const _field = field as FieldWithDefault
const def = input.onDefault?.(_field) ?? onDefault(_field)
if (def) {
imports = imports.concat(def.imports ?? [])
func += def.code
let hasDefaultFn = false
if (custom?.default) {
hasDefaultFn = true
func += `.$defaultFn(${custom.default})`
} else {
// Legacy `drizzle.default` directive
const customDefault = getCustomDefault(field)
if (customDefault) {
hasDefaultFn = true
imports = imports.concat(customDefault.imports)
func += customDefault.code
} else if (field.hasDefaultValue) {
const _field = field as FieldWithDefault
const def = input.onDefault?.(_field) ?? onDefault(_field)
if (def) {
imports = imports.concat(def.imports ?? [])
func += def.code
}
}
}

if (field.isId) func += input.onPrimaryKey?.(field) ?? '.primaryKey()'
else if (field.isRequired || field.hasDefaultValue || !!customDefault)
else if (field.isRequired || field.hasDefaultValue || hasDefaultFn)
func += '.notNull()'

return {
Expand Down
71 changes: 71 additions & 0 deletions packages/generator/src/lib/adapter/fields/directives/custom.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import type { DMMF } from '@prisma/generator-helper'
import * as v from 'valibot'
import getErrorMessage from '~/lib/error-message'

const DIRECTIVE = 'drizzle.custom'

export function getCustomDirective(field: DMMF.Field) {
const directiveInput = field.documentation
if (directiveInput == null || !directiveInput.startsWith(DIRECTIVE)) {
return
}

const parsing = v.safeParse(DirectiveSchema, parseJson(directiveInput))
if (!parsing.success)
throw new InvalidDirectiveShapeError({
input: directiveInput,
issues: parsing.issues,
})
return parsing.output
}

const NamedImportSchema = v.pipe(
v.array(v.string()),
v.transform((names) => ({
type: 'named-import' as const,
names,
}))
)

const DefaultImportSchema = v.pipe(
v.string(),
v.transform((name) => ({
type: 'default-import' as const,
name,
}))
)

const ImportSchema = v.object({
name: v.union([NamedImportSchema, DefaultImportSchema]),
/** e.g. "drizzle-orm" or "../my-type" */
module: v.string(),
/** Marks the import as a type import */
type: v.optional(v.boolean()),
})

const DirectiveSchema = v.object({
imports: v.optional(v.array(ImportSchema)),
$type: v.optional(v.string()),
default: v.optional(v.string()),
})

class InvalidDirectiveShapeError extends Error {
constructor(args: {
input: string
issues: [v.BaseIssue<unknown>, ...v.BaseIssue<unknown>[]]
}) {
super(
`Invalid ${DIRECTIVE} definition:\n\n— Error:${JSON.stringify(v.flatten(args.issues), null, 2)}\n—\n\n— Your Input\n${args.input}\n—`
)
}
}

function parseJson(directiveInput: string) {
try {
return JSON.parse(directiveInput.replace(DIRECTIVE, ''))
} catch (err) {
throw new Error(
`Invalid ${DIRECTIVE} JSON shape\n\n— Error:\n${getErrorMessage(err)}\n—\n\n— Your Input\n${directiveInput}\n—`
)
}
}
31 changes: 31 additions & 0 deletions packages/generator/src/lib/error-message.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Based on https://kentcdodds.com/blog/get-a-catch-block-error-message-with-typescript
*/
export default function getErrorMessage(error: unknown) {
return toErrorWithMessage(error).message
}

type ErrorWithMessage = {
message: string
}

function isErrorWithMessage(error: unknown): error is ErrorWithMessage {
return (
typeof error === 'object' &&
error !== null &&
'message' in error &&
typeof (error as Record<string, unknown>).message === 'string'
)
}

function toErrorWithMessage(maybeError: unknown): ErrorWithMessage {
if (isErrorWithMessage(maybeError)) return maybeError

try {
return new Error(JSON.stringify(maybeError))
} catch {
// fallback in case there's an error stringifying the maybeError
// like with circular references for example.
return new Error(String(maybeError))
}
}
10 changes: 10 additions & 0 deletions packages/usage/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,13 @@ model Disambiguating_Currency {
}

// #endregion

model FieldCustomization {
id Int @id @default(autoincrement())
/// drizzle.custom {
/// "imports": [{ "name": ["SomeBigInt"], "module": "~/tests/shared/testFieldCustomization", "type": true }],
/// "$type": "SomeBigInt",
/// "default": "() => 1n"
/// }
allFields BigInt
}
2 changes: 2 additions & 0 deletions packages/usage/tests/mysql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { db } from 'src/lib/mysql'
import type { Db, Schema } from 'src/lib/types'
import { testDefault } from './shared/testDefault'
import { testDisambiguatingRelationship } from './shared/testDisambiguatingRelationship'
import { testFieldCustomization } from './shared/testFieldCustomization'
import { testFields } from './shared/testFields'
import { testIgnoreDecorator } from './shared/testIgnoreDecorator'
import { testManyToMany } from './shared/testManyToMany'
Expand All @@ -25,3 +26,4 @@ testDisambiguatingRelationship(ctx)
testSelfReferring(ctx)
testIgnoreDecorator(ctx)
testDefault(ctx)
testFieldCustomization(ctx)
2 changes: 2 additions & 0 deletions packages/usage/tests/postgres.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { schema } from 'prisma/drizzle/schema'
import { db } from 'src/lib/postgres'
import { testDefault } from './shared/testDefault'
import { testDisambiguatingRelationship } from './shared/testDisambiguatingRelationship'
import { testFieldCustomization } from './shared/testFieldCustomization'
import { testFields } from './shared/testFields'
import { testIgnoreDecorator } from './shared/testIgnoreDecorator'
import { testManyToMany } from './shared/testManyToMany'
Expand All @@ -20,3 +21,4 @@ testDisambiguatingRelationship(ctx)
testSelfReferring(ctx)
testIgnoreDecorator(ctx)
testDefault(ctx)
testFieldCustomization(ctx)
29 changes: 29 additions & 0 deletions packages/usage/tests/shared/testFieldCustomization.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import fs, { readFileSync } from 'node:fs'
import type { TestContext } from 'tests/utils/types'
import { beforeAll, describe, expect, test } from 'vitest'

export type SomeBigInt = bigint

export const OUTPUT_FILE = './prisma/drizzle/field-customizations.ts'

export function testFieldCustomization({ db, schema, provider }: TestContext) {
let content: string
beforeAll(async () => {
expect(fs.existsSync(OUTPUT_FILE)).toBe(true)
content = readFileSync(OUTPUT_FILE, 'utf-8')
})

describe('allFields', () => {
test('should contain import', () => {
expect(content).include(
"import type { SomeBigInt } from '~/tests/shared/testFieldCustomization'"
)
})

test('should contain correct field definition', () => {
expect(content).include(
"allFields: bigint('allFields', { mode: 'bigint' }).$type<SomeBigInt>().$defaultFn(() => 1n)"
)
})
})
}
2 changes: 2 additions & 0 deletions packages/usage/tests/sqlite.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { db } from 'src/lib/sqlite'
import type { Db, Schema } from 'src/lib/types'
import { testDefault } from './shared/testDefault'
import { testDisambiguatingRelationship } from './shared/testDisambiguatingRelationship'
import { testFieldCustomization } from './shared/testFieldCustomization'
import { testFields } from './shared/testFields'
import { testIgnoreDecorator } from './shared/testIgnoreDecorator'
import { testManyToMany } from './shared/testManyToMany'
Expand All @@ -25,3 +26,4 @@ testDisambiguatingRelationship(ctx)
testSelfReferring(ctx)
testIgnoreDecorator(ctx)
testDefault(ctx)
testFieldCustomization(ctx)
3 changes: 3 additions & 0 deletions packages/usage/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"esModuleInterop": true,
"skipLibCheck": true,
"baseUrl": ".",
"paths": {
"~/tests/*": ["./tests/*"]
},
"moduleResolution": "Node",
"types": ["vitest/globals"]
}
Expand Down

0 comments on commit 1458235

Please sign in to comment.