Skip to content

Commit

Permalink
treat model validator as an instance method when mode='after' (#779)
Browse files Browse the repository at this point in the history
* treat model validator as an instance method when mode='after'

* Fix unittest

* Fix unittest
  • Loading branch information
koxudaxi authored Aug 5, 2023
1 parent 77c7f48 commit bd665a1
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 28 deletions.
32 changes: 24 additions & 8 deletions src/com/koxudaxi/pydantic/Pydantic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ val V2_VALIDATOR_QUALIFIED_NAMES = listOf(
MODEL_VALIDATOR_SHORT_QUALIFIED_NAME
)

val MODEL_VALIDATOR_QUALIFIED_NAMES = listOf(
MODEL_VALIDATOR_QUALIFIED_NAME,
MODEL_VALIDATOR_SHORT_QUALIFIED_NAME
)
val FIELD_VALIDATOR_Q_NAMES = listOf(
VALIDATOR_Q_NAME,
VALIDATOR_SHORT_Q_NAME,
Expand Down Expand Up @@ -241,13 +245,9 @@ internal fun isSubClassOfCustomBaseModel(pyClass: PyClass, context: TypeEvalCont
internal val PyClass.isBaseSettings: Boolean get() = qualifiedName == BASE_SETTINGS_Q_NAME


internal fun hasDecorator(pyDecoratable: PyDecoratable, refNames: List<QualifiedName>): Boolean {
return pyDecoratable.decoratorList?.decorators?.mapNotNull { it.callee as? PyReferenceExpression }?.any {
PyResolveUtil.resolveImportedElementQNameLocally(it).any { decoratorQualifiedName ->
refNames.any { refName -> decoratorQualifiedName == refName }
}
} ?: false
}
internal fun hasDecorator(pyDecoratable: PyDecoratable, refNames: List<QualifiedName>): Boolean =
pyDecoratable.decoratorList?.decorators?.any {it.include(refNames)} ?: false


internal val PyClass.isPydanticDataclass: Boolean get() = hasDecorator(this, DATA_CLASS_QUALIFIED_NAMES)

Expand All @@ -269,11 +269,27 @@ internal fun isDataclassMissing(pyTargetExpression: PyTargetExpression): Boolean
return pyTargetExpression.qualifiedName == DATACLASS_MISSING
}

internal fun PyFunction.isValidatorMethod(pydanticVersion: KotlinVersion?): Boolean =
internal fun PyFunction.hasValidatorMethod(pydanticVersion: KotlinVersion?): Boolean =
hasDecorator(this, if(pydanticVersion.isV2) V2_VALIDATOR_QUALIFIED_NAMES else VALIDATOR_QUALIFIED_NAMES)

internal fun PyDecorator.include(refNames: List<QualifiedName>): Boolean = (callee as? PyReferenceExpression)?.let {
PyResolveUtil.resolveImportedElementQNameLocally(it).any { decoratorQualifiedName ->
refNames.any { refName -> decoratorQualifiedName == refName }
}
} ?: false

internal val PyKeywordArgument.value: PyExpression?
get() = when (val value = valueExpression) {
is PyReferenceExpression -> (value.reference.resolve() as? PyTargetExpression)?.findAssignedValue()
else -> value
}

internal fun PyFunction.hasModelValidatorModeAfter(): Boolean = decoratorList?.decorators
?.filter { it.include(MODEL_VALIDATOR_QUALIFIED_NAMES) }
?.any { modelValidator ->
modelValidator.argumentList?.getKeywordArgument("mode")
?.let { it.value as? PyStringLiteralExpression }?.stringValue == "after"
} ?: false
internal val PyClass.isConfigClass: Boolean get() = name == "Config"


Expand Down
14 changes: 6 additions & 8 deletions src/com/koxudaxi/pydantic/PydanticIgnoreInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package com.koxudaxi.pydantic

import com.intellij.psi.PsiReference
import com.jetbrains.python.inspections.PyInspectionExtension
import com.jetbrains.python.psi.PyElement
import com.jetbrains.python.psi.PyFunction
import com.jetbrains.python.psi.PyStringLiteralExpression
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.types.TypeEvalContext

class PydanticIgnoreInspection : PyInspectionExtension() {
Expand All @@ -20,10 +18,10 @@ class PydanticIgnoreInspection : PyInspectionExtension() {
}

override fun ignoreMethodParameters(function: PyFunction, context: TypeEvalContext): Boolean {
return function.containingClass?.let {
isPydanticModel(it,
true,
context) && function.isValidatorMethod(PydanticCacheService.getVersion(function.project))
} == true
val pyClass = function.containingClass ?: return false
if (!isPydanticModel(pyClass, true, context)) return false
if (!function.hasValidatorMethod(PydanticCacheService.getVersion(function.project))) return false
if (function.hasModelValidatorModeAfter()) return false
return true
}
}
11 changes: 4 additions & 7 deletions src/com/koxudaxi/pydantic/PydanticInspection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class PydanticInspection : PyInspection() {
super.visitPyFunction(node)

if (getPydanticModelByAttribute(node, true, myTypeEvalContext) == null) return
if (!node.isValidatorMethod(pydanticCacheService.getOrPutVersion())) return
if (!node.hasValidatorMethod(pydanticCacheService.getOrPutVersion())) return
if (node.hasModelValidatorModeAfter()) return
val paramList = node.parameterList
val params = paramList.parameters
val firstParam = params.firstOrNull()
Expand Down Expand Up @@ -98,13 +99,9 @@ class PydanticInspection : PyInspection() {
private fun inspectValidatorField(pyStringLiteralExpression: PyStringLiteralExpression) {
if (pyStringLiteralExpression.reference?.resolve() != null) return
val pyArgumentList = pyStringLiteralExpression.parent as? PyArgumentList ?: return
pyArgumentList.getKeywordArgument("check_fields")?.let { it ->
val checkFields = when (val value = it.valueExpression){
is PyReferenceExpression -> (value.reference.resolve() as? PyTargetExpression)?.findAssignedValue()
else -> value
}?.let { PyEvaluator.evaluateAsBoolean(it) }
pyArgumentList.getKeywordArgument("check_fields")?.let {
// ignore unresolved value
if (checkFields != true) return
if (PyEvaluator.evaluateAsBoolean(it.value)!= true) return
}
val stringValue = pyStringLiteralExpression.stringValue
if (stringValue == "*") return
Expand Down
4 changes: 2 additions & 2 deletions src/com/koxudaxi/pydantic/PydanticTypeProvider.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ class PydanticTypeProvider : PyTypeProviderBase() {
getRefTypeFromFieldName(name, context, pyClass)
}

param.isSelf && func.isValidatorMethod(PydanticCacheService.getVersion(func.project)
) -> {
param.isSelf && func.hasValidatorMethod(PydanticCacheService.getVersion(func.project)) && !func.hasModelValidatorModeAfter()
-> {
val pyClass = func.containingClass ?: return null
if (!isPydanticModel(pyClass, false, context)) return null
context.getType(pyClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import com.jetbrains.python.PythonLanguage
import com.jetbrains.python.codeInsight.PyCodeInsightSettings
import com.jetbrains.python.psi.PyFunction
import com.jetbrains.python.psi.impl.PyPsiUtils
import com.jetbrains.python.psi.types.TypeEvalContext
import java.util.regex.Pattern

class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() {
Expand Down Expand Up @@ -53,7 +52,7 @@ class PydanticTypedValidatorMethodHandler : TypedHandlerDelegate() {
val defNode = maybeDef.node
if (defNode != null && defNode.elementType === PyTokenTypes.DEF_KEYWORD) {
val pyFunction = token.parent as? PyFunction ?: return Result.CONTINUE
if (!pyFunction.isValidatorMethod(PydanticCacheService.getVersion(project))) return Result.CONTINUE
if (!pyFunction.hasValidatorMethod(PydanticCacheService.getVersion(project))) return Result.CONTINUE
val settings = CodeStyle.getLanguageSettings(file, PythonLanguage.getInstance())
val textToType = StringBuilder()
textToType.append("(")
Expand Down
9 changes: 9 additions & 0 deletions testData/ignoreinspection/validatorModeAfter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel, model_validator


class A(BaseModel):
a: str

@model_validator(mode='after')
def vali<caret>date_a(self):
pass
10 changes: 9 additions & 1 deletion testData/inspectionv2/validatorField.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,15 @@ def validate_c(**kwargs):
def validate_c(**kwargs):
pass

@model_validator('x')
@model_validator(mode='before')
def validate_model_before(cls):
pass

@model_validator(mode='after')
def validate_model_after(self):
pass

@model_validator()
def validate_model(cls):
pass

Expand Down
6 changes: 6 additions & 0 deletions testData/inspectionv2/validatorSelf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ def validate_e<error descr="Method must have a first parameter, usually called '
def validate_model<error descr="Method must have a first parameter, usually called 'cls'">()</error>:
pass

@model_validator(mode='after')
def validate_model_after(self):
pass

@model_validator(mode='before')
def validate_model_before<error descr="Method must have a first parameter, usually called 'cls'">()</error>:
pass
def dummy(self):
pass

Expand Down
4 changes: 4 additions & 0 deletions testSrc/com/koxudaxi/pydantic/PydanticIgnoreInspectionTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,8 @@ open class PydanticIgnoreInspectionTest : PydanticTestCase() {
fun testDecoratorField() {
doIgnoreUnresolvedReference(false)
}

fun testValidatorModeAfter() {
doIgnoreMethodParametersTest(false)
}
}

0 comments on commit bd665a1

Please sign in to comment.