diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0711ce7..b6880d9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,6 +13,11 @@ jobs: steps: - name: ⬇️ Checkout repo uses: actions/checkout@v4 + + - name: ⎔ Setup node + uses: actions/setup-node@v4 + with: + node-version: 22 - name: ⎔ Setup bun uses: oven-sh/setup-bun@v1 diff --git a/TESTING.md b/TESTING.md new file mode 100644 index 0000000..d26028f --- /dev/null +++ b/TESTING.md @@ -0,0 +1,11 @@ +## Testing Notes + +> Quick notes on the findings when building the test, not a comprehensive guide yet. + +### On using Vitest + +- Although this project uses bun for the package manager, Vitest is used for testing solely because of its [typechecking capabilities](https://vitest.dev/guide/testing-types). + +- When running Vitest using Bun (`bun vitest run`), there are still `node` calls occurring in the background. Therefore, ensure that you specify `bunx prisma-generator-drizzle` when building the Prisma schema. + +- We're considering a full migration to `pnpm` + `vitest` in the future for the sake of consistency. \ No newline at end of file diff --git a/bun.lockb b/bun.lockb index ece21b8..da95f88 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/package.json b/package.json index 14a4ed7..7e0df00 100644 --- a/package.json +++ b/package.json @@ -1,9 +1,10 @@ { "name": "prisma-generator-drizzle-root", - "workspaces": ["packages/*", "examples/*"], "devDependencies": { "@biomejs/biome": "^1.8.2", - "lefthook": "^1.6.16" + "lefthook": "^1.6.16", + "turbo": "^2.0.6" }, - "packageManager": "bun@1.1.10" + "packageManager": "bun@1.1.20", + "workspaces": ["packages/*", "examples/*"] } diff --git a/packages/generator/package.json b/packages/generator/package.json index aeeebd8..972a732 100644 --- a/packages/generator/package.json +++ b/packages/generator/package.json @@ -20,6 +20,7 @@ "test": "bun test" }, "dependencies": { + "@mrleebo/prisma-ast": "^0.12.0", "@prisma/client": "5.10.2", "@prisma/generator-helper": "5.10.2", "@prisma/sdk": "4.0.0", @@ -35,7 +36,7 @@ "@types/prettier": "3.0.0", "prisma": "5.10.2", "tsup": "^8.0.2", - "typescript": "5.4.2" + "typescript": "5.5.3" }, "repository": { "type": "git", diff --git a/packages/generator/src/generator.ts b/packages/generator/src/generator.ts index 716d44c..516dd7d 100644 --- a/packages/generator/src/generator.ts +++ b/packages/generator/src/generator.ts @@ -1,6 +1,7 @@ import { execSync } from 'node:child_process' import fs from 'node:fs' import path from 'node:path' +import { createPrismaSchemaBuilder } from '@mrleebo/prisma-ast' import { type DMMF, type GeneratorOptions, @@ -23,6 +24,7 @@ import type { RelationalModuleSet } from './lib/adapter/modules/relational' import { generateSchemaModules as generateSchemaModule } from './lib/adapter/modules/relational' import type { BaseGeneratedModules } from './lib/adapter/modules/sets/base-generated-modules' import { logger } from './lib/logger' +import { createSchema } from './lib/prisma-helpers/schema/schema' import { type ImportValue, type NamedImport, @@ -46,6 +48,11 @@ generatorHandler({ } }, onGenerate: async (options: GeneratorOptions) => { + const schema = createSchema({ + astSchema: createPrismaSchemaBuilder(options.datamodel).getSchema(), + dmmf: options.dmmf, + }) + initializeGenerator(options) logger.log('Generating drizzle schema...') @@ -55,14 +62,14 @@ generatorHandler({ const modules: GeneratedModules = { extras: adapter.extraModules, enums: generateEnumModules(adapter), - models: generateModelModules(adapter), + models: generateModelModules(adapter, schema), } if (isRelationalQueryEnabled()) { - const relational = generateRelationalModules(modules.models) + const relational = generateRelationalModules(schema, modules.models) modules.relational = relational - const implicit = generateImplicitModules(adapter, relational) + const implicit = generateImplicitModules(adapter, schema, relational) modules.implicitModels = implicit.models modules.implicitRelational = implicit.relational diff --git a/packages/generator/src/lib/adapter/adapter.ts b/packages/generator/src/lib/adapter/adapter.ts index 78ef3f6..b407b84 100644 --- a/packages/generator/src/lib/adapter/adapter.ts +++ b/packages/generator/src/lib/adapter/adapter.ts @@ -1,13 +1,8 @@ -import type { - PrismaEnumField, - PrismaScalarField, -} from '../prisma-helpers/field' +import type { SchemaField } from '../prisma-helpers/schema/schema-field' import type { ImportValue } from '../syntaxes/imports' import type { Module } from '../syntaxes/module' import type { FieldFunc } from './fields/createField' -export type ParsableField = PrismaScalarField | PrismaEnumField - type DeclarationFunc = { imports: ImportValue[]; func: string } export function createAdapter(impl: { @@ -17,16 +12,13 @@ export function createAdapter(impl: { table: (name: string, fields: FieldFunc[]) => DeclarationFunc } fields: Partial< - Record< - PrismaScalarField['type'] | 'enum', - (field: ParsableField) => FieldFunc - > + Record FieldFunc> > extraModules?: Module[] }) { return { ...impl, - parseField(field: ParsableField) { + parseField(field: SchemaField) { const fieldType = field.kind === 'enum' ? 'enum' : field.type const fieldFunc = fieldType in impl.fields ? impl.fields[fieldType] : undefined diff --git a/packages/generator/src/lib/adapter/declarations/generateTableDeclaration.ts b/packages/generator/src/lib/adapter/declarations/generateTableDeclaration.ts index 207f8ee..738b1ae 100644 --- a/packages/generator/src/lib/adapter/declarations/generateTableDeclaration.ts +++ b/packages/generator/src/lib/adapter/declarations/generateTableDeclaration.ts @@ -1,19 +1,20 @@ -import type { DMMF } from '@prisma/generator-helper' import { or } from 'fp-ts/lib/Refinement' import { pipe } from 'fp-ts/lib/function' import { isKind } from '~/lib/prisma-helpers/field' -import { getDbName } from '~/lib/prisma-helpers/getDbName' -import { getModelVarName } from '~/lib/prisma-helpers/model' +import { + type SchemaModel, + getModelFields, +} from '~/lib/prisma-helpers/schema/schema-model' import type { Adapter } from '../types' -export function generateTableDeclaration(adapter: Adapter, model: DMMF.Model) { - const fields = model.fields +export function generateTableDeclaration(adapter: Adapter, model: SchemaModel) { + const fields = getModelFields(model) .filter(pipe(isKind('scalar'), or(isKind('enum')))) .map(adapter.parseField) - const name = getModelVarName(model) + const name = model.getVarName() const tableDeclaration = adapter.getDeclarationFunc.table( - getDbName(model), + model.getDbName(), fields ) diff --git a/packages/generator/src/lib/adapter/declarations/generateTableRelationsDeclaration.ts b/packages/generator/src/lib/adapter/declarations/generateTableRelationsDeclaration.ts index e190fba..fc9662b 100644 --- a/packages/generator/src/lib/adapter/declarations/generateTableRelationsDeclaration.ts +++ b/packages/generator/src/lib/adapter/declarations/generateTableRelationsDeclaration.ts @@ -3,19 +3,22 @@ import { map } from 'fp-ts/lib/Array' import { pipe } from 'fp-ts/lib/function' import { camelCase, kebabCase } from 'lodash' import pluralize from 'pluralize' -import { - type PrismaRelationField, - isRelationField, -} from '~/lib/prisma-helpers/field' -import { getDbName } from '~/lib/prisma-helpers/getDbName' import { getModelVarName } from '~/lib/prisma-helpers/model' +import type { Schema } from '~/lib/prisma-helpers/schema/schema' +import type { SchemaFieldRelational } from '~/lib/prisma-helpers/schema/schema-field' +import { + type SchemaModel, + createSchemaModel, + findCorrespondingAstModel, + getModelFields, +} from '~/lib/prisma-helpers/schema/schema-model' import { namedImport } from '../../syntaxes/imports' import type { ModelModule } from '../modules/model' type GenerateTableRelationsInput = { - fields: PrismaRelationField[] + fields: Array modelModule: ModelModule - datamodel: DMMF.Datamodel + schema: Schema } export function generateTableRelationsDeclaration( @@ -40,7 +43,7 @@ export function generateTableRelationsDeclaration( } function getRelationField(ctx: GenerateTableRelationsInput) { - return (field: PrismaRelationField) => { + return (field: SchemaFieldRelational) => { const { implicit, opts, referenceModelVarName } = !field.isList ? getOneToOneOrManyRelation(field, ctx) : opposingIsList(field, ctx) @@ -64,16 +67,21 @@ function getRelationField(ctx: GenerateTableRelationsInput) { } class DetermineRelationshipError extends Error { - constructor(field: DMMF.Field, message: string) { + constructor(field: SchemaFieldRelational, message: string) { super(`Cannot determine relationship ${field.relationName}, ${message}`) } } function getManyToManyRelation( - field: PrismaRelationField, + field: SchemaFieldRelational, ctx: GenerateTableRelationsInput ) { - const opposingModel = findOpposingRelationModel(field, ctx.datamodel) + if (field.relationName == null) + throw new Error( + `relationName is null for ${field.name} of ${field.model.getDbName()}` + ) + + const opposingModel = findOpposingRelationModel(field, ctx.schema) const joinTable = createImplicitJoinTable(ctx, field.relationName, [ ctx.modelModule.model, opposingModel, @@ -98,17 +106,17 @@ function createRelation(input: { } function holdsForeignKey(args: { - field: PrismaRelationField - model: DMMF.Model + field: SchemaFieldRelational + model: SchemaModel }) { const { field, model } = args - return model.fields.some((f) => + return getModelFields(model).some((f) => field.relationFromFields.some((from) => f.name === from) ) } function getOneToOneOrManyRelation( - field: PrismaRelationField, + field: SchemaFieldRelational, ctx: GenerateTableRelationsInput ) { if (hasReference(field)) { @@ -118,7 +126,7 @@ function getOneToOneOrManyRelation( ? createRelationOpts({ relationName: field.relationName, from: { - modelVarName: getModelVarName(ctx.modelModule.model), + modelVarName: ctx.modelModule.model.getVarName(), fieldNames: field.relationFromFields, }, to: { @@ -132,7 +140,7 @@ function getOneToOneOrManyRelation( // For disambiguating relation - const opposingModel = findOpposingRelationModel(field, ctx.datamodel) + const opposingModel = findOpposingRelationModel(field, ctx.schema) const opposingField = findOpposingRelationField(field, opposingModel) return createRelation({ @@ -148,7 +156,7 @@ function getOneToOneOrManyRelation( ? createRelationOpts({ relationName: field.relationName, from: { - modelVarName: getModelVarName(ctx.modelModule.model), + modelVarName: ctx.modelModule.model.getVarName(), fieldNames: opposingField.relationToFields, }, to: { @@ -160,7 +168,7 @@ function getOneToOneOrManyRelation( }) } -function getManyToOneRelation(field: PrismaRelationField) { +function getManyToOneRelation(field: SchemaFieldRelational) { const opts = createRelationOpts({ relationName: field.relationName }) return createRelation({ referenceModelVarName: getModelVarName(field.type), @@ -210,9 +218,9 @@ function createRelationOpts(input: { function createImplicitJoinTable( ctx: GenerateTableRelationsInput, baseName: string, - models: [DMMF.Model, DMMF.Model] + models: [SchemaModel, SchemaModel] ) { - const pair = models.map(getDbName).sort() + const pair = models.map((model) => model.getDbName()).sort() // Custom varName following drizzle's convention const name = pipe(pair, map(pluralize), (names) => names.join('To')) @@ -247,7 +255,7 @@ function createImplicitJoinTable( type: pair[0], // relationName: `${baseName}_A`, relationFromFields: ['A'], - relationToFields: [findModelPrimaryKey(ctx.datamodel, pair[0]).name], + relationToFields: [findModelPrimaryKey(ctx.schema, pair[0]).name], isGenerated: false, isUpdatedAt: false, }, @@ -276,7 +284,7 @@ function createImplicitJoinTable( type: pair[1], // relationName: `${baseName}_B`, relationFromFields: ['B'], - relationToFields: [findModelPrimaryKey(ctx.datamodel, pair[1]).name], + relationToFields: [findModelPrimaryKey(ctx.schema, pair[1]).name], isGenerated: false, isUpdatedAt: false, }, @@ -290,8 +298,10 @@ function createImplicitJoinTable( return { varName, baseName, model, pair } } -function findModelPrimaryKey(datamodel: DMMF.Datamodel, modelName: string) { - const model = datamodel.models.find((model) => model.name === modelName) +function findModelPrimaryKey(schema: Schema, modelName: string) { + const model = schema.dmmf.datamodel.models.find( + (model) => model.name === modelName + ) if (model == null) throw new Error(`Model ${modelName} not found`) const pkField = model.fields.find((field) => field.isId) if (pkField == null) @@ -300,22 +310,29 @@ function findModelPrimaryKey(datamodel: DMMF.Datamodel, modelName: string) { } function findOpposingRelationModel( - field: PrismaRelationField, - datamodel: DMMF.Datamodel + field: SchemaFieldRelational, + schema: Schema ) { - const opposingModel = datamodel.models.find((m) => m.name === field.type) - if (opposingModel) return opposingModel + const opposingModel = schema.dmmf.datamodel.models.find( + (m) => m.name === field.type + ) + if (opposingModel) { + return createSchemaModel({ + astModel: findCorrespondingAstModel(schema.ast, opposingModel), + dmmfModel: opposingModel, + }) + } throw new DetermineRelationshipError(field, `model ${field.type} not found`) } function findOpposingRelationField( - field: PrismaRelationField, - opposingModel: DMMF.Model + field: SchemaFieldRelational, + opposingModel: SchemaModel ) { - const opposingField = opposingModel.fields.find( - (f) => f.relationName === field.relationName && isRelationField(f) + const opposingField = getModelFields(opposingModel).find( + (f) => f.isRelationField && f.relationName === field.relationName ) - if (opposingField) return opposingField as PrismaRelationField + if (opposingField) return opposingField as SchemaFieldRelational throw new DetermineRelationshipError( field, `field with relation ${field.relationName} not found` @@ -326,29 +343,29 @@ function findOpposingRelationField( * Not a derived relation in which the model holds the reference. * Can be one-to-one or one-to-many */ -function hasReference(field: PrismaRelationField) { +function hasReference(field: SchemaFieldRelational) { return ( field.relationFromFields.length > 0 && field.relationToFields.length > 0 ) } function opposingIsList( - field: PrismaRelationField, + field: SchemaFieldRelational, ctx: GenerateTableRelationsInput ) { - const opposingModel = findOpposingRelationModel(field, ctx.datamodel) + const opposingModel = findOpposingRelationModel(field, ctx.schema) return findOpposingRelationField(field, opposingModel).isList } function hasMultipleDisambiguatingRelations(args: { - field: PrismaRelationField - model: DMMF.Model + field: SchemaFieldRelational + model: SchemaModel }): boolean { let count = 0 - for (const field of args.model.fields) { + for (const field of getModelFields(args.model)) { if ( field.type === args.field.type && - isRelationField(field) && + field.isRelationField && !hasReference(field) ) { count++ diff --git a/packages/generator/src/lib/adapter/fields/createField.ts b/packages/generator/src/lib/adapter/fields/createField.ts index 15fc336..67cce85 100644 --- a/packages/generator/src/lib/adapter/fields/createField.ts +++ b/packages/generator/src/lib/adapter/fields/createField.ts @@ -1,11 +1,15 @@ import type { DMMF } from '@prisma/generator-helper' import { getDirective } from '~/lib/directive' +import type { + SchemaField, + SchemaFieldWithDefault, +} from '~/lib/prisma-helpers/schema/schema-field' import { type ImportValue, defaultImportValue, namedImport, } from '~/lib/syntaxes/imports' -import type { MakeRequired, ModifyType, Prettify } from '~/lib/types/utils' +import type { ModifyType, Prettify } from '~/lib/types/utils' import { getCustomDirective } from './directives/custom' export type DefineImport = { @@ -14,7 +18,7 @@ export type DefineImport = { } export interface CreateFieldInput { - field: DMMF.Field + field: SchemaField imports?: ImportValue[] func: | string @@ -22,7 +26,7 @@ export interface CreateFieldInput { onDefault?: ( field: FieldWithDefault ) => { code: string; imports?: ImportValue[] } | undefined - onPrimaryKey?: (field: DMMF.Field) => string | undefined + onPrimaryKey?: (field: SchemaField) => string | undefined } export type FieldFunc = ReturnType @@ -90,7 +94,7 @@ export function createField(input: CreateFieldInput) { } } -function getCustomType(field: DMMF.Field) { +function getCustomType(field: SchemaField) { const directive = getDirective(field, 'drizzle.type') if (directive == null) return @@ -107,7 +111,7 @@ function getCustomType(field: DMMF.Field) { } } -function getCustomDefault(field: DMMF.Field) { +function getCustomDefault(field: SchemaField) { const directive = getDirective(field, 'drizzle.default') if (directive == null) return @@ -138,11 +142,11 @@ function getCustomDefault(field: DMMF.Field) { } // #region onDefault -export type FieldWithDefault = Prettify> - -export function hasDefault(field: DMMF.Field): field is FieldWithDefault { - return field.hasDefaultValue -} +export type FieldWithDefault = ModifyType< + SchemaFieldWithDefault, + 'isRelationField', + boolean | undefined // Remove this after removing `DMMF` usages +> function isDefaultScalar( field: FieldWithDefault diff --git a/packages/generator/src/lib/adapter/fields/directives/custom.ts b/packages/generator/src/lib/adapter/fields/directives/custom.ts index 12b3e41..feafca3 100644 --- a/packages/generator/src/lib/adapter/fields/directives/custom.ts +++ b/packages/generator/src/lib/adapter/fields/directives/custom.ts @@ -1,10 +1,10 @@ -import type { DMMF } from '@prisma/generator-helper' import * as v from 'valibot' import getErrorMessage from '~/lib/error-message' +import type { SchemaField } from '~/lib/prisma-helpers/schema/schema-field' const DIRECTIVE = 'drizzle.custom' -export function getCustomDirective(field: DMMF.Field) { +export function getCustomDirective(field: SchemaField) { const directiveInput = field.documentation if (directiveInput == null || !directiveInput.startsWith(DIRECTIVE)) { return diff --git a/packages/generator/src/lib/adapter/modules/model.ts b/packages/generator/src/lib/adapter/modules/model.ts index fe85475..9234a96 100644 --- a/packages/generator/src/lib/adapter/modules/model.ts +++ b/packages/generator/src/lib/adapter/modules/model.ts @@ -1,24 +1,33 @@ -import type { DMMF } from '@prisma/generator-helper' -import { getGenerator } from '~/shared/generator-context' -import { getModelModuleName } from '../../prisma-helpers/model' +import type { Schema } from '~/lib/prisma-helpers/schema/schema' +import { + type SchemaModel, + createSchemaModel, + findCorrespondingAstModel, +} from '~/lib/prisma-helpers/schema/schema-model' import { createModule } from '../../syntaxes/module' import { generateTableDeclaration } from '../declarations/generateTableDeclaration' import type { Adapter } from '../types' -export function generateModelModules(adapter: Adapter) { - return getGenerator().dmmf.datamodel.models.map(createModelModule(adapter)) +export function generateModelModules(adapter: Adapter, schema: Schema) { + return schema.dmmf.datamodel.models.map((dmmfModel) => { + return createModelModule( + adapter, + createSchemaModel({ + dmmfModel, + astModel: findCorrespondingAstModel(schema.ast, dmmfModel), + }) + ) + }) } -export function createModelModule(adapter: Adapter) { - return (model: DMMF.Model) => { - const tableVar = generateTableDeclaration(adapter, model) +export function createModelModule(adapter: Adapter, model: SchemaModel) { + const tableVar = generateTableDeclaration(adapter, model) - return createModule({ - name: getModelModuleName(model), - model: model, - tableVar, - declarations: [tableVar], - }) - } + return createModule({ + name: model.getModuleName(), + model: model, + tableVar, + declarations: [tableVar], + }) } -export type ModelModule = ReturnType> +export type ModelModule = ReturnType diff --git a/packages/generator/src/lib/adapter/modules/relational.ts b/packages/generator/src/lib/adapter/modules/relational.ts index 2e33fb4..d941117 100644 --- a/packages/generator/src/lib/adapter/modules/relational.ts +++ b/packages/generator/src/lib/adapter/modules/relational.ts @@ -1,9 +1,12 @@ import type { DMMF } from '@prisma/generator-helper' import { isEmpty } from 'lodash' import { deduplicateModels } from '~/generator' -import { isRelationField } from '~/lib/prisma-helpers/field' +import type { Schema } from '~/lib/prisma-helpers/schema/schema' +import { + createSchemaModel, + getModelFields, +} from '~/lib/prisma-helpers/schema/schema-model' import { type Module, createModule } from '~/lib/syntaxes/module' -import { getGenerator } from '~/shared/generator-context' import { generateSchemaDeclaration } from '../declarations/generateSchemaDeclaration' import { generateTableRelationsDeclaration } from '../declarations/generateTableRelationsDeclaration' import type { Adapter } from '../types' @@ -16,24 +19,32 @@ export type RelationalModuleSet = { implicitRelational: Module[] } -export function generateRelationalModules(modelModules: ModelModule[]) { +export function generateRelationalModules( + schema: Schema, + modelModules: ModelModule[] +) { return modelModules.flatMap((modelModule) => { - const relationalModule = createRelationalModule(modelModule) + const relationalModule = createRelationalModule(schema, modelModule) if (relationalModule == null) return [] return relationalModule }) } -export function createRelationalModule(modelModule: ModelModule) { +export function createRelationalModule( + schema: Schema, + modelModule: ModelModule +) { const { model } = modelModule - const relationalFields = model.fields.filter(isRelationField) + const relationalFields = getModelFields(model).filter( + (field) => field.isRelationField === true + ) if (isEmpty(relationalFields)) return undefined const declaration = generateTableRelationsDeclaration({ fields: relationalFields, modelModule: modelModule, - datamodel: getGenerator().dmmf.datamodel, + schema, }) return createModule({ name: `${modelModule.name}-relations`, @@ -48,15 +59,18 @@ export type RelationalModule = NonNullable< export function generateImplicitModules( adapter: Adapter, + schema: Schema, relationalModules: RelationalModule[] ) { const models = relationalModules .flatMap((module) => module.implicit) .reduce(deduplicateModels, [] as DMMF.Model[]) - .map(createModelModule(adapter)) + .map((dmmfModel) => { + return createModelModule(adapter, createSchemaModel({ dmmfModel })) + }) const relational = models.flatMap((modelModule) => { - const relationalModule = createRelationalModule(modelModule) + const relationalModule = createRelationalModule(schema, modelModule) if (relationalModule == null) return [] return relationalModule }) diff --git a/packages/generator/src/lib/adapter/providers/mysql.ts b/packages/generator/src/lib/adapter/providers/mysql.ts index 984ce1e..b9cdfe9 100644 --- a/packages/generator/src/lib/adapter/providers/mysql.ts +++ b/packages/generator/src/lib/adapter/providers/mysql.ts @@ -4,7 +4,7 @@ import { namedImport } from '~/lib/syntaxes/imports' import { createModule } from '~/lib/syntaxes/module' import { getDateMode } from '~/shared/date-mode' import { createAdapter } from '../adapter' -import { createField, hasDefault, isDefaultFunc } from '../fields/createField' +import { createField, isDefaultFunc } from '../fields/createField' import type { BigIntMode } from '../fields/directives/custom' const coreModule = 'drizzle-orm/mysql-core' @@ -103,7 +103,7 @@ export const mysqlAdapter = createAdapter({ // https://github.com/drizzle-team/drizzle-orm/issues/921 onDefault: (field) => { if ( - hasDefault(field) && + field.hasDefaultValue && isDefaultFunc(field) && field.default.name === 'now' ) { diff --git a/packages/generator/src/lib/adapter/providers/postgres.ts b/packages/generator/src/lib/adapter/providers/postgres.ts index 88b0f73..99cec7a 100644 --- a/packages/generator/src/lib/adapter/providers/postgres.ts +++ b/packages/generator/src/lib/adapter/providers/postgres.ts @@ -7,7 +7,6 @@ import { createAdapter } from '../adapter' import { type CreateFieldInput, createField as baseCreateField, - hasDefault, isDefaultFunc, } from '../fields/createField' import type { BigIntMode } from '../fields/directives/custom' @@ -73,7 +72,7 @@ export const postgresAdapter = createAdapter({ // https://orm.drizzle.team/docs/column-types/pg/#bigint BigInt(field) { const func = - hasDefault(field) && + field.hasDefaultValue && isDefaultFunc(field) && field.default.name === 'autoincrement' ? 'bigserial' @@ -129,7 +128,7 @@ export const postgresAdapter = createAdapter({ // https://orm.drizzle.team/docs/column-types/pg/#integer Int(field) { const func = - hasDefault(field) && + field.hasDefaultValue && isDefaultFunc(field) && field.default.name === 'autoincrement' ? // https://arc.net/l/quote/mpimqrfn diff --git a/packages/generator/src/lib/adapter/providers/sqlite.ts b/packages/generator/src/lib/adapter/providers/sqlite.ts index c97abab..d5b1f02 100644 --- a/packages/generator/src/lib/adapter/providers/sqlite.ts +++ b/packages/generator/src/lib/adapter/providers/sqlite.ts @@ -3,7 +3,7 @@ import { namedImport } from '~/lib/syntaxes/imports' import { createModule } from '~/lib/syntaxes/module' import { getDateMode } from '~/shared/date-mode' import { createAdapter } from '../adapter' -import { createField, hasDefault, isDefaultFunc } from '../fields/createField' +import { createField, isDefaultFunc } from '../fields/createField' const coreModule = 'drizzle-orm/sqlite-core' @@ -139,7 +139,7 @@ export const sqliteAdapter = createAdapter({ func: `integer('${getDbName(field)}', { mode: 'number' })`, onPrimaryKey(field) { if ( - hasDefault(field) && + field.hasDefaultValue && isDefaultFunc(field) && field.default.name === 'autoincrement' ) diff --git a/packages/generator/src/lib/directive.ts b/packages/generator/src/lib/directive.ts index c4773f2..7459b77 100644 --- a/packages/generator/src/lib/directive.ts +++ b/packages/generator/src/lib/directive.ts @@ -1,4 +1,4 @@ -import type { DMMF } from '@prisma/generator-helper' +import type { SchemaField } from './prisma-helpers/schema/schema-field' /** * e.g. @@ -9,7 +9,7 @@ import type { DMMF } from '@prisma/generator-helper' * - Input: drizzle.type viem::Address * - Returns: viem:Address */ -export function getDirective(field: DMMF.Field, directive: string) { +export function getDirective(field: SchemaField, directive: string) { if (field.documentation == null) return return field.documentation diff --git a/packages/generator/src/lib/prisma-helpers/field.ts b/packages/generator/src/lib/prisma-helpers/field.ts index 87ddb65..0c5714b 100644 --- a/packages/generator/src/lib/prisma-helpers/field.ts +++ b/packages/generator/src/lib/prisma-helpers/field.ts @@ -1,4 +1,5 @@ import type { DMMF } from '@prisma/generator-helper' +import type { SchemaField } from './schema/schema-field' export type PrismaFieldType = | 'BigInt' @@ -36,16 +37,15 @@ export interface PrismaRelationField export interface PrismaObjectField extends Omit { kind: 'object' } -type PrismaField = PrismaScalarField | PrismaEnumField | PrismaObjectField -export function isKind(kind: TKind) { - return (field: DMMF.Field): field is Extract => +export function isKind(kind: TKind) { + return ( + field: SchemaField + ): field is Extract => field.kind === kind } -export function isRelationField( - field: DMMF.Field -): field is PrismaRelationField { +export function isRelationField(field: DMMF.Field) { return ( field.kind === 'object' && field.relationFromFields != null && diff --git a/packages/generator/src/lib/prisma-helpers/schema/schema-field.ts b/packages/generator/src/lib/prisma-helpers/schema/schema-field.ts new file mode 100644 index 0000000..f5877cd --- /dev/null +++ b/packages/generator/src/lib/prisma-helpers/schema/schema-field.ts @@ -0,0 +1,82 @@ +import type { Field } from '@mrleebo/prisma-ast' +import type { DMMF } from '@prisma/generator-helper' +import type { SchemaModel } from './schema-model' + +export type SchemaField = ReturnType + +export function createSchemaField(args: { + model: SchemaModel + dmmfField: DMMF.Field + astField?: Field +}) { + const { model, dmmfField } = args + + const field = { + model, + isRelationField: undefined as false | undefined, + name: dmmfField.name, + isList: dmmfField.isList, + ...(() => { + if (dmmfField.default != null) { + return { + default: dmmfField.default, + hasDefaultValue: true, + } as const + } + return { + hasDefaultValue: false, + } as const + })(), + kind: dmmfField.kind, + type: dmmfField.type, + documentation: dmmfField.documentation, + isId: dmmfField.isId, + isRequired: dmmfField.isRequired, + getDbName() { + return dmmfField?.dbName ?? dmmfField?.name + }, + } as const + + if ( + field.kind === 'object' && + dmmfField.relationFromFields != null && + dmmfField.relationToFields != null + ) { + return { + ...field, + kind: field.kind, + isRelationField: true, + relationName: dmmfField.relationName, + relationFromFields: dmmfField.relationFromFields, + relationToFields: dmmfField.relationToFields, + } as const + } + + return field +} + +export type SchemaFieldWithDefault = Extract< + SchemaField, + { hasDefaultValue: true } +> + +export type SchemaFieldRelational = Extract< + SchemaField, + { isRelationField: true } +> + +export function findCorrespondingAstField( + model: SchemaModel, + dmmfField: DMMF.Field +) { + if (model.ast == null) + throw new Error(`Model ${model.dmmf.name} has no corresponding ast model`) + + const astField = model.ast.properties.find( + (prop) => prop.type === 'field' && prop.name === dmmfField.name + ) + if (astField?.type !== 'field') { + throw new Error(`Ast field ${dmmfField.name} not found`) + } + return astField +} diff --git a/packages/generator/src/lib/prisma-helpers/schema/schema-model.ts b/packages/generator/src/lib/prisma-helpers/schema/schema-model.ts new file mode 100644 index 0000000..dd1ab3d --- /dev/null +++ b/packages/generator/src/lib/prisma-helpers/schema/schema-model.ts @@ -0,0 +1,52 @@ +import type { Model, Schema } from '@mrleebo/prisma-ast' +import type { DMMF } from '@prisma/generator-helper' +import { getDbName } from '../getDbName' +import { getModelModuleName, getModelVarName } from '../model' +import { createSchemaField, findCorrespondingAstField } from './schema-field' + +export type SchemaModel = ReturnType + +export function createSchemaModel(args: { + astModel?: Model + dmmfModel: DMMF.Model +}) { + const { astModel, dmmfModel } = args + + return { + ast: astModel, + dmmf: dmmfModel, + getDbName() { + return getDbName(dmmfModel) + }, + getVarName() { + return getModelVarName(dmmfModel) + }, + getModuleName() { + return getModelModuleName(dmmfModel) + }, + } +} + +export function getModelFields(model: SchemaModel) { + return model.dmmf.fields.map((dmmfField) => { + return createSchemaField({ + model, + dmmfField, + astField: model.ast + ? findCorrespondingAstField(model, dmmfField) + : undefined, + }) + }) +} + +export function findCorrespondingAstModel( + astSchema: Schema, + dmmfModel: DMMF.Model +) { + const astModel = astSchema.list.find( + (block) => block.type === 'model' && block.name === dmmfModel.name + ) + if (astModel?.type !== 'model') + throw new Error(`Cannot find corresponding ast model for ${dmmfModel.name}`) + return astModel +} diff --git a/packages/generator/src/lib/prisma-helpers/schema/schema.ts b/packages/generator/src/lib/prisma-helpers/schema/schema.ts new file mode 100644 index 0000000..0b4a649 --- /dev/null +++ b/packages/generator/src/lib/prisma-helpers/schema/schema.ts @@ -0,0 +1,16 @@ +import type { Schema as AstSchema } from '@mrleebo/prisma-ast' +import type { DMMF } from '@prisma/generator-helper' + +export function createSchema(args: { + astSchema: AstSchema + dmmf: { + datamodel: DMMF.Datamodel + } +}) { + return { + ast: args.astSchema, + dmmf: args.dmmf, + } +} + +export type Schema = ReturnType diff --git a/packages/generator/src/shared/date-mode.ts b/packages/generator/src/shared/date-mode.ts index 3c77817..4b74a96 100644 --- a/packages/generator/src/shared/date-mode.ts +++ b/packages/generator/src/shared/date-mode.ts @@ -1,11 +1,11 @@ import { parse, picklist } from 'valibot' -import type { ParsableField } from '~/lib/adapter/adapter' import { getDirective } from '~/lib/directive' +import type { SchemaField } from '~/lib/prisma-helpers/schema/schema-field' import { getGenerator } from './generator-context' export const DateMode = picklist(['string', 'date']) -export function getDateMode(field: ParsableField) { +export function getDateMode(field: SchemaField) { const directive = getDirective(field, 'drizzle.dateMode') if (directive) return parse(DateMode, directive) diff --git a/packages/usage/package.json b/packages/usage/package.json index 2753f9d..4a8fb26 100644 --- a/packages/usage/package.json +++ b/packages/usage/package.json @@ -35,7 +35,7 @@ "bun-types": "^1.0.30", "prisma": "5.15.0", "prisma-generator-drizzle": "workspace:*", - "typescript": "5.4.2", + "typescript": "5.5.3", "vitest": "^1.6.0" } } diff --git a/packages/usage/tests/configure-date-mode.test.ts b/packages/usage/tests/configure-date-mode.test.ts index 7f8a673..6f05342 100644 --- a/packages/usage/tests/configure-date-mode.test.ts +++ b/packages/usage/tests/configure-date-mode.test.ts @@ -19,7 +19,7 @@ test('global config', async () => { } generator drizzle { - provider = "prisma-generator-drizzle" + provider = "bunx prisma-generator-drizzle" dateMode = "string" output = "drizzle.ts" } @@ -48,7 +48,7 @@ test('field-level config', async () => { } generator drizzle { - provider = "prisma-generator-drizzle" + provider = "bunx prisma-generator-drizzle" output = "drizzle.ts" }