diff --git a/econpizza/parser/__init__.py b/econpizza/parser/__init__.py index 68a89e3..417ae9d 100644 --- a/econpizza/parser/__init__.py +++ b/econpizza/parser/__init__.py @@ -148,7 +148,7 @@ def _load_external_functions_file(model, context): module = _load_as_module(model["functions_file"]) def func_or_compiled(func): return isinstance( - func, jaxlib.xla_extension.CompiledFunction) or isfunction(func) + func, jaxlib.xla_extension.PjitFunction) or isfunction(func) for m in getmembers(module, func_or_compiled): context[m[0]] = m[1]