Skip to content

Commit

Permalink
feat: Support for dbgenerated default (#27)
Browse files Browse the repository at this point in the history
* Handle nested directive

* Add test for dbgenerated

* Allow sending import from onDefault

* Pass dbgenerated field

* Update func

---------

Co-authored-by: farreldarian <[email protected]>
  • Loading branch information
fdarian and farreldarian authored Jan 20, 2024
1 parent 7996651 commit f43f345
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 33 deletions.
Binary file modified bun.lockb
Binary file not shown.
56 changes: 38 additions & 18 deletions packages/generator/src/lib/adapter/fields/createField.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ interface CreateFieldInput {
field: DMMF.Field
imports?: ImportValue[]
func: string
onDefault?: (field: FieldWithDefault) => string | undefined
onDefault?: (
field: FieldWithDefault
) => { code: string; imports?: ImportValue[] } | undefined
onPrimaryKey?: (field: DMMF.Field) => string | undefined
}

Expand All @@ -37,7 +39,11 @@ export function createField(input: CreateFieldInput) {
func += customDefault.code
} else if (field.hasDefaultValue) {
const _field = field as FieldWithDefault
func += input.onDefault?.(_field) ?? onDefault(_field)
const def = input.onDefault?.(_field) ?? onDefault(_field)
if (def) {
imports = imports.concat(def.imports ?? [])
func += def.code
}
}

if (field.isId) func += input.onPrimaryKey?.(field) ?? '.primaryKey()'
Expand Down Expand Up @@ -146,47 +152,61 @@ function isDefaultScalarList(
return Array.isArray(field.default)
}

function onDefault(field: FieldWithDefault) {
if (
isDefaultFunc(field) &&
field.type === 'DateTime' &&
field.default.name === 'now'
) {
return `.defaultNow()`
function onDefault(
field: FieldWithDefault
): { imports?: ImportValue[]; code: string } | undefined {
if (isDefaultFunc(field)) {
if (field.default.name === 'dbgenerated') {
return {
imports: [namedImport(['sql'], 'drizzle-orm')],
code: `.default(sql\`${field.default.args[0]}\`)`,
}
}

if (field.type === 'DateTime' && field.default.name === 'now') {
return { code: '.defaultNow()' }
}
}

if (isDefaultScalar(field)) {
if (field.type === 'Bytes') {
return `.$defaultFn(() => Buffer.from('${field.default}', 'base64'))`
return {
code: `.$defaultFn(() => Buffer.from('${field.default}', 'base64'))`,
}
}

const defaultDef = getDefaultScalarDefinition(field, field.default)

if (defaultDef == null) return ''
return `.default(${defaultDef})`
if (defaultDef == null) return
return {
code: `.default(${defaultDef})`,
}
}

if (isDefaultScalarList(field)) {
if (field.type === 'Bytes') {
return `.$defaultFn(() => [ ${field.default
.map((value) => `Buffer.from('${value}', 'base64')`)
.join(', ')} ])`
return {
code: `.$defaultFn(() => [ ${field.default
.map((value) => `Buffer.from('${value}', 'base64')`)
.join(', ')} ])`,
}
}

const defaultDefs = field.default.map((value) =>
getDefaultScalarDefinition(field, value)
)

if (defaultDefs.some((val) => val == null)) return ''
return `.default([${defaultDefs.join(', ')}])`
if (defaultDefs.some((val) => val == null)) return
return {
code: `.default([${defaultDefs.join(', ')}])`,
}
}

console.warn(
`Unsupported default value: ${JSON.stringify(field.default)} on field ${
field.name
}`
)
return ''
}

function getDefaultScalarDefinition(
Expand Down
33 changes: 19 additions & 14 deletions packages/generator/src/lib/adapter/providers/mysql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,29 @@ export const mysqlAdapter = createAdapter({
},
// https://orm.drizzle.team/docs/column-types/mysql#datetime
DateTime(field) {
const hasDefaultNow =
hasDefault(field) &&
isDefaultFunc(field) &&
field.default.name === 'now'

return createField({
field,
imports: [
namedImport(['datetime'], coreModule),
...(hasDefaultNow ? [namedImport(['sql'], 'drizzle-orm')] : []),
],
imports: [namedImport(['datetime'], coreModule)],
func: `datetime('${getDbName(field)}', { mode: 'date', fsp: 3 })`,
// https://github.com/drizzle-team/drizzle-orm/issues/921
onDefault: (field) => {
if (hasDefaultNow) {
return `.default(sql\`CURRENT_TIMESTAMP(3)\`)`
if (
hasDefault(field) &&
isDefaultFunc(field) &&
field.default.name === 'now'
) {
return {
imports: [namedImport(['sql'], 'drizzle-orm')],
code: '.default(sql`CURRENT_TIMESTAMP(3)`)',
}
}

// Drizzle doesn't respect the timezone, different on postgres
// Might be caused by https://github.com/drizzle-team/drizzle-orm/issues/1442
if (field.type === 'DateTime') {
return `.$defaultFn(() => new Date('${field.default}'))`
return {
code: `.$defaultFn(() => new Date('${field.default}'))`,
}
}
},
})
Expand Down Expand Up @@ -136,7 +137,9 @@ export const mysqlAdapter = createAdapter({
isDefaultFunc(field) &&
field.default.name === 'autoincrement'
) {
return `.autoincrement()`
return {
code: '.autoincrement()',
}
}
},
})
Expand All @@ -147,7 +150,9 @@ export const mysqlAdapter = createAdapter({
field,
imports: [namedImport(['json'], coreModule)],
func: `json('${getDbName(field)}')`,
onDefault: (field) => `.$defaultFn(() => (${field.default}))`,
onDefault: (field) => ({
code: `.$defaultFn(() => (${field.default}))`,
}),
})
},
// https://orm.drizzle.team/docs/column-types/mysql/#text
Expand Down
4 changes: 3 additions & 1 deletion packages/usage/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,21 @@
"@paralleldrive/cuid2": "^2.2.2",
"@prisma/client": "5.8.1",
"better-sqlite3": "^9.3.0",
"decimal.js": "^10.4.3",
"date-fns": "^3.2.0",
"decimal.js": "^10.4.3",
"drizzle-orm": "^0.29.3",
"mysql2": "^3.7.1",
"pg": "^8.11.3",
"postgres": "^3.4.3",
"uuid": "^9.0.1",
"valibot": "^0.26.0",
"vitest": "^1.2.1"
},
"devDependencies": {
"@types/better-sqlite3": "^7.6.8",
"@types/node": "20.11.5",
"@types/pg": "^8.10.9",
"@types/uuid": "^9.0.7",
"bun-types": "^1.0.23",
"prisma": "5.8.1",
"prisma-generator-drizzle": "workspace:*",
Expand Down
1 change: 1 addition & 0 deletions packages/usage/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ model Default {
alsoId String?
/// drizzle.default crypto::randomBytes `() => randomBytes(16).toString('hex')`
salt String?
pgUuid String? @default(dbgenerated("gen_random_uuid()")) // -sqlite -mysql
date DateTime? @default("2024-01-23T00:00:00Z")
int Int? @default(1)
boolean Boolean? @default(true)
Expand Down
5 changes: 5 additions & 0 deletions packages/usage/tests/shared/testDefault.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { createId, isCuid } from '@paralleldrive/cuid2'
import { isAfter, subSeconds } from 'date-fns'
import Decimal from 'decimal.js'
import { eq, inArray } from 'drizzle-orm'
import { validate as validateUuid } from 'uuid'
import { throwIfnull } from 'tests/utils/query'
import { TestContext } from 'tests/utils/types'
import { describe, expect, test } from 'vitest'
Expand Down Expand Up @@ -45,6 +46,10 @@ export function testDefault({ db, schema, provider }: TestContext) {
expect(result.float, 'Invalid float').toBe(1.123)
expect(result.bytes.toString(), 'Invalid bytes').toBe('hello world')

if (provider === 'postgres') {
expect(validateUuid(result.pgUuid)).toBe(true)
}

if (provider !== 'sqlite') {
expect(result.enum, 'Invalid enum').toBe('TypeTwo')
expect(result.json, 'Invalid json').toStrictEqual({ foo: 'bar' })
Expand Down

0 comments on commit f43f345

Please sign in to comment.