Skip to content

Commit

Permalink
fix(json): support enums in type declarations (#1837)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Nov 7, 2024
1 parent 1d1fec0 commit 1bd1b8f
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 21 deletions.
23 changes: 17 additions & 6 deletions packages/language/src/generated/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ export function isTypeDeclaration(item: unknown): item is TypeDeclaration {
return reflection.isInstance(item, TypeDeclaration);
}

export type TypeDefFieldTypes = Enum | TypeDef;

export const TypeDefFieldTypes = 'TypeDefFieldTypes';

export function isTypeDefFieldTypes(item: unknown): item is TypeDefFieldTypes {
return reflection.isInstance(item, TypeDefFieldTypes);
}

export interface Argument extends AstNode {
readonly $container: InvocationExpr;
readonly $type: 'Argument';
Expand Down Expand Up @@ -654,7 +662,7 @@ export interface TypeDefFieldType extends AstNode {
readonly $type: 'TypeDefFieldType';
array: boolean
optional: boolean
reference?: Reference<TypeDef>
reference?: Reference<TypeDefFieldTypes>
type?: BuiltinType
}

Expand Down Expand Up @@ -738,14 +746,15 @@ export type ZModelAstType = {
TypeDef: TypeDef
TypeDefField: TypeDefField
TypeDefFieldType: TypeDefFieldType
TypeDefFieldTypes: TypeDefFieldTypes
UnaryExpr: UnaryExpr
UnsupportedFieldType: UnsupportedFieldType
}

export class ZModelAstReflection extends AbstractAstReflection {

getAllTypes(): string[] {
return ['AbstractDeclaration', 'Argument', 'ArrayExpr', 'Attribute', 'AttributeArg', 'AttributeParam', 'AttributeParamType', 'BinaryExpr', 'BooleanLiteral', 'ConfigArrayExpr', 'ConfigExpr', 'ConfigField', 'ConfigInvocationArg', 'ConfigInvocationExpr', 'DataModel', 'DataModelAttribute', 'DataModelField', 'DataModelFieldAttribute', 'DataModelFieldType', 'DataSource', 'Enum', 'EnumField', 'Expression', 'FieldInitializer', 'FunctionDecl', 'FunctionParam', 'FunctionParamType', 'GeneratorDecl', 'InternalAttribute', 'InvocationExpr', 'LiteralExpr', 'MemberAccessExpr', 'Model', 'ModelImport', 'NullExpr', 'NumberLiteral', 'ObjectExpr', 'Plugin', 'PluginField', 'ReferenceArg', 'ReferenceExpr', 'ReferenceTarget', 'StringLiteral', 'ThisExpr', 'TypeDeclaration', 'TypeDef', 'TypeDefField', 'TypeDefFieldType', 'UnaryExpr', 'UnsupportedFieldType'];
return ['AbstractDeclaration', 'Argument', 'ArrayExpr', 'Attribute', 'AttributeArg', 'AttributeParam', 'AttributeParamType', 'BinaryExpr', 'BooleanLiteral', 'ConfigArrayExpr', 'ConfigExpr', 'ConfigField', 'ConfigInvocationArg', 'ConfigInvocationExpr', 'DataModel', 'DataModelAttribute', 'DataModelField', 'DataModelFieldAttribute', 'DataModelFieldType', 'DataSource', 'Enum', 'EnumField', 'Expression', 'FieldInitializer', 'FunctionDecl', 'FunctionParam', 'FunctionParamType', 'GeneratorDecl', 'InternalAttribute', 'InvocationExpr', 'LiteralExpr', 'MemberAccessExpr', 'Model', 'ModelImport', 'NullExpr', 'NumberLiteral', 'ObjectExpr', 'Plugin', 'PluginField', 'ReferenceArg', 'ReferenceExpr', 'ReferenceTarget', 'StringLiteral', 'ThisExpr', 'TypeDeclaration', 'TypeDef', 'TypeDefField', 'TypeDefFieldType', 'TypeDefFieldTypes', 'UnaryExpr', 'UnsupportedFieldType'];
}

protected override computeIsSubtype(subtype: string, supertype: string): boolean {
Expand Down Expand Up @@ -775,16 +784,18 @@ export class ZModelAstReflection extends AbstractAstReflection {
case ConfigArrayExpr: {
return this.isSubtype(ConfigExpr, supertype);
}
case DataModel:
case Enum:
case TypeDef: {
case DataModel: {
return this.isSubtype(AbstractDeclaration, supertype) || this.isSubtype(TypeDeclaration, supertype);
}
case DataModelField:
case EnumField:
case FunctionParam: {
return this.isSubtype(ReferenceTarget, supertype);
}
case Enum:
case TypeDef: {
return this.isSubtype(AbstractDeclaration, supertype) || this.isSubtype(TypeDeclaration, supertype) || this.isSubtype(TypeDefFieldTypes, supertype);
}
case InvocationExpr:
case LiteralExpr: {
return this.isSubtype(ConfigExpr, supertype) || this.isSubtype(Expression, supertype);
Expand Down Expand Up @@ -821,7 +832,7 @@ export class ZModelAstReflection extends AbstractAstReflection {
return ReferenceTarget;
}
case 'TypeDefFieldType:reference': {
return TypeDef;
return TypeDefFieldTypes;
}
default: {
throw new Error(`${referenceId} is not a valid reference id.`);
Expand Down
31 changes: 26 additions & 5 deletions packages/language/src/generated/grammar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2165,7 +2165,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "CrossReference",
"type": {
"$ref": "#/types@1"
"$ref": "#/types@2"
},
"terminal": {
"$type": "RuleCall",
Expand Down Expand Up @@ -2267,7 +2267,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
},
"arguments": []
},
"cardinality": "+"
"cardinality": "*"
},
{
"$type": "Keyword",
Expand Down Expand Up @@ -2375,7 +2375,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "CrossReference",
"type": {
"$ref": "#/rules@40"
"$ref": "#/types@1"
},
"terminal": {
"$type": "RuleCall",
Expand Down Expand Up @@ -2827,7 +2827,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "CrossReference",
"type": {
"$ref": "#/types@1"
"$ref": "#/types@2"
},
"terminal": {
"$type": "RuleCall",
Expand Down Expand Up @@ -3255,7 +3255,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "CrossReference",
"type": {
"$ref": "#/types@1"
"$ref": "#/types@2"
},
"terminal": {
"$type": "RuleCall",
Expand Down Expand Up @@ -3838,6 +3838,27 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
]
}
},
{
"$type": "Type",
"name": "TypeDefFieldTypes",
"type": {
"$type": "UnionType",
"types": [
{
"$type": "SimpleType",
"typeRef": {
"$ref": "#/rules@40"
}
},
{
"$type": "SimpleType",
"typeRef": {
"$ref": "#/rules@44"
}
}
]
}
},
{
"$type": "Type",
"name": "TypeDeclaration",
Expand Down
8 changes: 5 additions & 3 deletions packages/language/src/zmodel.langium
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,17 @@ TypeDef:
(comments+=TRIPLE_SLASH_COMMENT)*
'type' name=RegularID '{' (
fields+=TypeDefField
)+
)*
'}';

type TypeDefFieldTypes = TypeDef | Enum;

TypeDefField:
(comments+=TRIPLE_SLASH_COMMENT)*
(comments+=TRIPLE_SLASH_COMMENT)*
name=RegularIDWithTypeNames type=TypeDefFieldType (attributes+=DataModelFieldAttribute)*;

TypeDefFieldType:
(type=BuiltinType | reference=[TypeDef:RegularID]) (array?='[' ']')? (optional?='?')?;
(type=BuiltinType | reference=[TypeDefFieldTypes:RegularID]) (array?='[' ']')? (optional?='?')?;

UnsupportedFieldType:
'Unsupported' '(' (value=LiteralExpr) ')';
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { PluginError } from '@zenstackhq/sdk';
import { BuiltinType, TypeDef, TypeDefFieldType } from '@zenstackhq/sdk/ast';
import { getDataModels, PluginError } from '@zenstackhq/sdk';
import { BuiltinType, Enum, isEnum, TypeDef, TypeDefFieldType } from '@zenstackhq/sdk/ast';
import { SourceFile } from 'ts-morph';
import { match } from 'ts-pattern';
import { name } from '..';
Expand Down Expand Up @@ -36,7 +36,11 @@ function zmodelTypeToTsType(type: TypeDefFieldType) {
if (type.type) {
result = builtinTypeToTsType(type.type);
} else if (type.reference?.ref) {
result = type.reference.ref.name;
if (isEnum(type.reference.ref)) {
result = makeEnumTypeReference(type.reference.ref);
} else {
result = type.reference.ref.name;
}
} else {
throw new PluginError(name, `Unsupported field type: ${type}`);
}
Expand All @@ -61,3 +65,17 @@ function builtinTypeToTsType(type: BuiltinType) {
.with('Json', () => 'unknown')
.exhaustive();
}

function makeEnumTypeReference(enumDecl: Enum) {
const zmodel = enumDecl.$container;
const models = getDataModels(zmodel);

if (models.some((model) => model.fields.some((field) => field.type.reference?.ref === enumDecl))) {
// if the enum is referenced by any data model, Prisma already generates its type,
// we just need to reference it
return enumDecl.name;
} else {
// otherwise, we need to inline the enum
return enumDecl.fields.map((field) => `'${field.name}'`).join(' | ');
}
}
23 changes: 19 additions & 4 deletions packages/schema/src/plugins/zod/transformer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { indentString, isDiscriminatorField, type PluginOptions } from '@zenstackhq/sdk';
import { DataModel, isDataModel, isTypeDef, type Model } from '@zenstackhq/sdk/ast';
import { DataModel, Enum, isDataModel, isEnum, isTypeDef, type Model } from '@zenstackhq/sdk/ast';
import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers';
import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
import path from 'path';
Expand Down Expand Up @@ -53,6 +53,9 @@ export default class Transformer {
}

async generateEnumSchemas() {
const generated: string[] = [];

// generate for enums in DMMF
for (const enumType of this.enumTypes) {
const name = upperCaseFirst(enumType.name);
const filePath = path.join(Transformer.outputPath, `enums/${name}.schema.ts`);
Expand All @@ -61,14 +64,26 @@ export default class Transformer {
`z.enum(${JSON.stringify(enumType.values)})`
)}`;
this.sourceFiles.push(this.project.createSourceFile(filePath, content, { overwrite: true }));
generated.push(enumType.name);
}

// enums not referenced by data models are not in DMMF, deal with them separately
const extraEnums = this.zmodel.declarations.filter((d): d is Enum => isEnum(d) && !generated.includes(d.name));
for (const enumDecl of extraEnums) {
const name = upperCaseFirst(enumDecl.name);
const filePath = path.join(Transformer.outputPath, `enums/${name}.schema.ts`);
const content = `/* eslint-disable */\n${this.generateImportZodStatement()}\n${this.generateExportSchemaStatement(
`${name}`,
`z.enum(${JSON.stringify(enumDecl.fields.map((f) => f.name))})`
)}`;
this.sourceFiles.push(this.project.createSourceFile(filePath, content, { overwrite: true }));
generated.push(enumDecl.name);
}

this.sourceFiles.push(
this.project.createSourceFile(
path.join(Transformer.outputPath, `enums/index.ts`),
this.enumTypes
.map((enumType) => `export * from './${upperCaseFirst(enumType.name)}.schema';`)
.join('\n'),
generated.map((name) => `export * from './${upperCaseFirst(name)}.schema';`).join('\n'),
{ overwrite: true }
)
);
Expand Down
77 changes: 77 additions & 0 deletions tests/integration/tests/enhancements/json/crud.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,83 @@ describe('Json field CRUD', () => {
).toResolveTruthy();
});

it('respects enums used by data models', async () => {
const params = await loadSchema(
`
enum Role {
USER
ADMIN
}
type Profile {
role Role
}
model User {
id Int @id @default(autoincrement())
profile Profile @json
@@allow('all', true)
}
model Foo {
id Int @id @default(autoincrement())
role Role
}
`,
{
provider: 'postgresql',
dbUrl,
}
);

prisma = params.prisma;
const db = params.enhance();

await expect(db.user.create({ data: { profile: { role: 'MANAGER' } } })).toBeRejectedByPolicy();
await expect(db.user.create({ data: { profile: { role: 'ADMIN' } } })).resolves.toMatchObject({
profile: { role: 'ADMIN' },
});
await expect(db.user.findFirst()).resolves.toMatchObject({
profile: { role: 'ADMIN' },
});
});

it('respects enums unused by data models', async () => {
const params = await loadSchema(
`
enum Role {
USER
ADMIN
}
type Profile {
role Role
}
model User {
id Int @id @default(autoincrement())
profile Profile @json
@@allow('all', true)
}
`,
{
provider: 'postgresql',
dbUrl,
}
);

prisma = params.prisma;
const db = params.enhance();

await expect(db.user.create({ data: { profile: { role: 'MANAGER' } } })).toBeRejectedByPolicy();
await expect(db.user.create({ data: { profile: { role: 'ADMIN' } } })).resolves.toMatchObject({
profile: { role: 'ADMIN' },
});
await expect(db.user.findFirst()).resolves.toMatchObject({
profile: { role: 'ADMIN' },
});
});

it('respects @default', async () => {
const params = await loadSchema(
`
Expand Down
Loading

0 comments on commit 1bd1b8f

Please sign in to comment.