From a61f1a87d322ae50b63419f1f3ad06e50b755d66 Mon Sep 17 00:00:00 2001 From: Walter Gray Date: Thu, 16 Jan 2025 18:12:58 -0800 Subject: [PATCH] Add Python language binding using libnanobind --- core/src/ast/types.rs | 5 + tool/src/lib.rs | 3 + tool/src/python/binding.rs | 56 +++ tool/src/python/formatter.rs | 168 ++++++++ tool/src/python/mod.rs | 389 +++++++++++++++++++ tool/src/python/ty.rs | 401 ++++++++++++++++++++ tool/templates/python/binding.cpp.jinja | 16 + tool/templates/python/c_include.h.jinja | 9 + tool/templates/python/enum_impl.cpp.jinja | 6 + tool/templates/python/method_impl.cpp.jinja | 6 + tool/templates/python/module_impl.cpp.jinja | 3 + tool/templates/python/opaque_impl.cpp.jinja | 12 + tool/templates/python/struct_impl.cpp.jinja | 8 + 13 files changed, 1082 insertions(+) create mode 100644 tool/src/python/binding.rs create mode 100644 tool/src/python/formatter.rs create mode 100644 tool/src/python/mod.rs create mode 100644 tool/src/python/ty.rs create mode 100644 tool/templates/python/binding.cpp.jinja create mode 100644 tool/templates/python/c_include.h.jinja create mode 100644 tool/templates/python/enum_impl.cpp.jinja create mode 100644 tool/templates/python/method_impl.cpp.jinja create mode 100644 tool/templates/python/module_impl.cpp.jinja create mode 100644 tool/templates/python/opaque_impl.cpp.jinja create mode 100644 tool/templates/python/struct_impl.cpp.jinja diff --git a/core/src/ast/types.rs b/core/src/ast/types.rs index 195831712..eb7971a98 100644 --- a/core/src/ast/types.rs +++ b/core/src/ast/types.rs @@ -951,6 +951,11 @@ impl TypeName { if let syn::PathArguments::AngleBracketed(type_args) = &p.path.segments.last().unwrap().arguments { + assert!( + type_args.args.len() > 1, + "Not enough arguments given to Result. Are you using a non-std Result type?" 
+ ); + if let (syn::GenericArgument::Type(ok), syn::GenericArgument::Type(err)) = (&type_args.args[0], &type_args.args[1]) { diff --git a/tool/src/lib.rs b/tool/src/lib.rs index 5dae5355e..c285af791 100644 --- a/tool/src/lib.rs +++ b/tool/src/lib.rs @@ -8,6 +8,7 @@ mod dart; mod demo_gen; mod js; mod kotlin; +mod python; use colored::*; use core::mem; @@ -57,6 +58,7 @@ pub fn gen( demo_gen::attr_support() } "kotlin" => kotlin::attr_support(), + "python" => python::attr_support(), o => panic!("Unknown target: {}", o), }; @@ -73,6 +75,7 @@ pub fn gen( "cpp" => cpp::run(&tcx), "dart" => dart::run(&tcx, docs_url_gen), "js" => js::run(&tcx, docs_url_gen), + "python" => python::run(&tcx), "demo_gen" => { let conf = library_config.map(|c| { let str = std::fs::read_to_string(c) diff --git a/tool/src/python/binding.rs b/tool/src/python/binding.rs new file mode 100644 index 000000000..c96d89ccb --- /dev/null +++ b/tool/src/python/binding.rs @@ -0,0 +1,56 @@ +use askama::Template; +use std::borrow::Cow; +use std::collections::BTreeSet; +use std::fmt::{self}; +use std::string::String; + +/// This abstraction allows us to build up the binding piece by piece without needing +/// to precalculate things like the list of dependent headers or classes +#[derive(Default, Template)] +#[template(path = "python/binding.cpp.jinja", escape = "none")] +pub(super) struct Binding<'a> { + /// The module name for this binding + pub module_name: Cow<'a, str>, + /// A list of includes + /// + /// Example: + /// ```c + /// #include "Foo.h" + /// #include "Bar.h" + /// #include "diplomat_runtime.h" + /// ``` + pub includes: BTreeSet>, + /// The actual meat of the impl: usually will contain a type definition and methods + /// + /// Example: + /// ```c + /// typedef struct Foo { + /// uint8_t field1; + /// bool field2; + /// } Foo; + /// + /// Foo make_foo(uint8_t field1, bool field2); + /// ``` + pub body: String, +} + +impl Binding<'_> { + pub fn new() -> Self { + Binding { + includes: 
BTreeSet::new(), + ..Default::default() + } + } +} + +impl fmt::Write for Binding<'_> { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.body.write_str(s) + } + fn write_char(&mut self, c: char) -> fmt::Result { + self.body.write_char(c) + } + fn write_fmt(&mut self, args: fmt::Arguments<'_>) -> fmt::Result { + self.body.write_fmt(args) + } +} diff --git a/tool/src/python/formatter.rs b/tool/src/python/formatter.rs new file mode 100644 index 000000000..2f503bba7 --- /dev/null +++ b/tool/src/python/formatter.rs @@ -0,0 +1,168 @@ +//! This module contains functions for formatting types + +use crate::c::{CFormatter, CAPI_NAMESPACE}; +use diplomat_core::hir::{self, StringEncoding, TypeContext, TypeId}; +use std::borrow::Cow; + +/// This type mediates all formatting +/// +/// All identifiers from the HIR should go through here before being formatted +/// into the output: This makes it easy to handle reserved words or add rename support +/// +/// If you find yourself needing an identifier formatted in a context not yet available here, please add a new method +/// +/// This type may be used by other backends attempting to figure out the names +/// of C types and methods. +pub(crate) struct PyFormatter<'tcx> { + pub c: CFormatter<'tcx>, +} + +impl<'tcx> PyFormatter<'tcx> { + pub fn new(tcx: &'tcx TypeContext) -> Self { + Self { c: CFormatter::new(tcx, true) } + } + + /// Resolve and format the nested module names for this type + /// Returns an iterator to the namespaces. 
Will always have at least one entry + pub fn fmt_namespaces(&self, id: TypeId) -> impl Iterator> { + let resolved = self.c.tcx().resolve_type(id); + resolved.attrs().namespace.as_deref().unwrap_or("m").split("::").map(Cow::Borrowed) + } + + /// Resolve the name of the module to use + pub fn fmt_module(&self, id: TypeId) -> Cow<'tcx, str> { + self.fmt_namespaces(id).last().unwrap() + } + + /// Resolve and format a named type for use in code (without the namespace) + pub fn fmt_type_name_unnamespaced(&self, id: TypeId) -> Cow<'tcx, str> { + let resolved = self.c.tcx().resolve_type(id); + + resolved.attrs().rename.apply(resolved.name().as_str().into()) + } + + /// Resolve and format a named type for use in code + pub fn fmt_type_name(&self, id: TypeId) -> Cow<'tcx, str> { + let resolved = self.c.tcx().resolve_type(id); + let name = resolved.attrs().rename.apply(resolved.name().as_str().into()); + if let Some(ref ns) = resolved.attrs().namespace { + format!("{ns}::{name}").into() + } else { + name + } + } + + /// Resolve and format the name of a type for use in header names + pub fn fmt_decl_header_path(&self, id: TypeId) -> String { + let resolved = self.c.tcx().resolve_type(id); + let type_name = resolved.attrs().rename.apply(resolved.name().as_str().into()); + if let Some(ref ns) = resolved.attrs().namespace { + let ns = ns.replace("::", "/"); + format!("../cpp/{ns}/{type_name}.d.hpp") + } else { + format!("../cpp/{type_name}.d.hpp") + } + } + + /// Resolve and format the name of a type for use in header names + pub fn fmt_impl_file_path(&self, id: TypeId) -> String { + let resolved = self.c.tcx().resolve_type(id); + let type_name = resolved.attrs().rename.apply(resolved.name().as_str().into()); + if let Some(ref ns) = resolved.attrs().namespace { + let ns = ns.replace("::", "/"); + format!("../cpp/{ns}/{type_name}.hpp") + } else { + format!("../cpp/{type_name}.hpp") + } + } + + /// Format a field name or parameter name + // might need splitting in the future if we 
decide to support renames here + pub fn fmt_param_name<'a>(&self, ident: &'a str) -> Cow<'a, str> { + ident.into() + } + + pub fn fmt_c_type_name(&self, id: TypeId) -> Cow<'tcx, str> { + self.c.fmt_type_name_maybe_namespaced(id.into()) + } + + pub fn fmt_c_ptr<'a>(&self, ident: &'a str, mutability: hir::Mutability) -> Cow<'a, str> { + self.c.fmt_ptr(ident, mutability) + } + + pub fn fmt_optional(&self, ident: &str) -> String { + format!("std::optional<{ident}>") + } + + pub fn fmt_borrowed<'a>(&self, ident: &'a str, mutability: hir::Mutability) -> Cow<'a, str> { + // TODO: Where is the right place to put `const` here? + if mutability.is_mutable() { + format!("{ident}&").into() + } else { + format!("const {ident}&").into() + } + } + + pub fn fmt_move_ref<'a>(&self, ident: &'a str) -> Cow<'a, str> { + format!("{ident}&&").into() + } + + pub fn fmt_optional_borrowed<'a>(&self, ident: &'a str, mutability: hir::Mutability) -> Cow<'a, str> { + self.c.fmt_ptr(ident, mutability) + } + + pub fn fmt_owned<'a>(&self, ident: &'a str) -> Cow<'a, str> { + format!("std::unique_ptr<{ident}>").into() + } + + pub fn fmt_borrowed_slice<'a>(&self, ident: &'a str, mutability: hir::Mutability) -> Cow<'a, str> { + // TODO: This needs to change if an abstraction other than std::span is used + // TODO: Where is the right place to put `const` here? 
+ if mutability.is_mutable() { + format!("diplomat::span<{ident}>").into() + } else { + format!("diplomat::span").into() + } + } + + pub fn fmt_borrowed_str(&self, encoding: StringEncoding) -> Cow<'static, str> { + // TODO: This needs to change if an abstraction other than std::u8string_view is used + match encoding { + StringEncoding::Utf8 | StringEncoding::UnvalidatedUtf8 => "std::string_view".into(), + StringEncoding::UnvalidatedUtf16 => "std::u16string_view".into(), + _ => unreachable!(), + } + } + + pub fn fmt_owned_str(&self) -> Cow<'static, str> { + "std::string".into() + } + + /// Format a method + pub fn fmt_method_name<'a>(&self, method: &'a hir::Method) -> Cow<'a, str> { + let name = method.attrs.rename.apply(method.name.as_str().into()); + + // TODO(#60): handle other keywords + if name == "new" { + "new_".into() + } else if name == "default" { + "default_".into() + } else { + name + } + } + + pub fn namespace_c_method_name(&self, ty: TypeId, name: &str) -> String { + let resolved = self.c.tcx().resolve_type(ty); + if let Some(ref ns) = resolved.attrs().namespace { + format!("{ns}::{CAPI_NAMESPACE}::{name}") + } else { + format!("diplomat::{CAPI_NAMESPACE}::{name}") + } + } + + /// Get the primitive type as a C type + pub fn fmt_primitive_as_c(&self, prim: hir::PrimitiveType) -> Cow<'static, str> { + self.c.fmt_primitive_as_c(prim) + } +} diff --git a/tool/src/python/mod.rs b/tool/src/python/mod.rs new file mode 100644 index 000000000..be4f38086 --- /dev/null +++ b/tool/src/python/mod.rs @@ -0,0 +1,389 @@ +mod binding; +mod formatter; +mod ty; + +use std::{borrow::Cow, collections::HashSet}; + +use crate::{ErrorStore, FileMap}; +use binding::Binding; +use diplomat_core::hir::{self, BackendAttrSupport}; +use formatter::PyFormatter; +use ty::TyGenContext; + +/// Python support using the nanobind c++ library to create a python binding. +/// +/// Support for automated python package building is still outstanding. 
/// To build, modify the following in the build.rs for your diplomat library:
///
/// let py_out_dir = Path::new(&out_dir).join("py");
/// diplomat_tool::gen(
///   Path::new("src/lib.rs"),
///   "python",
///   &py_out_dir,
///   &DocsUrlGenerator::with_base_urls(None, Default::default()),
///   None,
///   false,
/// )
/// .expect("Error generating python");
///
/// // Run python to obtain the include path & linker
/// let pyconfig_out = Command::new("python")
///   .args(["-c", "from sysconfig import get_path\nprint(get_path(\"include\"))"])
///   .output()
///   .expect("Error running python");
/// assert!(pyconfig_out.status.success());
/// let py_include = String::from_utf8_lossy(&pyconfig_out.stdout);
/// let py_lib = Path::new::<str>(py_include.borrow()).parent().unwrap().join("libs");
///
/// // Compile libnanobind
/// let nanobind_dir = build_utils::get_workspace_root().unwrap().join("external").join("nanobind");
/// cc::Build::new()
///   .cpp(true)
///   .flag("-std:c++17")
///   .opt_level(3)
///   .define("NDEBUG", None)
///   .define("NB_COMPACT_ASSERTIONS", None)
///   .include(nanobind_dir.join("include"))
///   .include(nanobind_dir.join("ext").join("robin_map").join("include"))
///   .include(py_include.trim())
///   .file(nanobind_dir.join("src").join("nb_combined.cpp"))
///   .compile("nanobind-static");
///
/// // Compile our extension
/// let mut build = cc::Build::new();
/// build
///   .cpp(true)
///   .flag("-std:c++17")
///   .opt_level_str("s")
///   .define("NDEBUG", None)
///   .define("zm_EXPORTS", None)
///   .define("NB_COMPACT_ASSERTIONS", None)
///   // For windows:
///   .define("_WINDLL", None)
///   .define("_MBCS", None)
///   .define("_WINDOWS", None)
///   .link_lib_modifier("+whole-archive")
///   .file(py_out_dir.join("nanobindings.cpp"))
///   .include(nanobind_dir.join("include"))
///   .include(py_include.trim());
/// build.compile("zm_pyext");
///
/// println!("cargo::rustc-link-search=native={}",
py_lib.display()); + +pub(crate) fn attr_support() -> BackendAttrSupport { + let mut a = BackendAttrSupport::default(); + + a.namespacing = true; + a.memory_sharing = true; + a.non_exhaustive_structs = false; + a.method_overloading = true; + a.utf8_strings = true; + a.utf16_strings = true; + a.static_slices = true; + + a.constructors = false; // TODO + a.named_constructors = false; + a.fallible_constructors = false; + a.accessors = false; + a.comparators = false; // TODO + a.stringifiers = false; // TODO + a.iterators = false; // TODO + a.iterables = false; // TODO + a.indexing = false; // TODO + a.option = true; + a.callbacks = true; + a.traits = false; + + a +} + +pub(crate) fn run(tcx: &hir::TypeContext) -> (FileMap, ErrorStore) { + let files = FileMap::default(); + let formatter = PyFormatter::new(tcx); + let errors = ErrorStore::default(); + + let nanobind_filepath = "nanobindings.cpp"; + let mut binding = Binding::new(); + let mut submodules = HashSet::>::new(); + for (id, ty) in tcx.all_types() { + if ty.attrs().disable { + // Skip type if disabled + continue; + } + + let _type_name_unnamespaced = formatter.fmt_type_name(id); + let decl_header_path = formatter.fmt_decl_header_path(id); + let impl_file_path = formatter.fmt_impl_file_path(id); + + let mut context = TyGenContext { + formatter: &formatter, + errors: &errors, + c2: crate::c::TyGenContext { + tcx, + formatter: &formatter.c, + errors: &errors, + is_for_cpp: false, + id: id.into(), + decl_header_path: &decl_header_path, + impl_header_path: &impl_file_path, + }, + binding: &mut binding, + submodules: &mut submodules, + generating_struct_fields: false, + }; + + // Assert everything shares the same root namespace. If this becomes too restrictive, we can generate multiple modules maybe? 
+ if let Some(ns) = ty + .attrs() + .namespace + .as_ref() + .and_then(|ns| ns.split("::").next()) + { + if context.binding.module_name.is_empty() { + context.binding.module_name = Cow::from(ns); + } else { + assert_eq!(context.binding.module_name, Cow::from(ns)); + } + } + + context + .binding + .includes + .insert(impl_file_path.clone().into()); + + let guard = errors.set_context_ty(ty.name().as_str().into()); + match ty { + hir::TypeDef::Enum(o) => context.gen_enum_def(o, id), + hir::TypeDef::Opaque(o) => context.gen_opaque_def(o, id), + hir::TypeDef::Struct(s) => context.gen_struct_def(s, id), + hir::TypeDef::OutStruct(s) => context.gen_struct_def(s, id), + _ => unreachable!("unknown AST/HIR variant"), + } + drop(guard); + } + + files.add_file(nanobind_filepath.to_owned(), binding.to_string()); + + (files, errors) +} + +#[cfg(test)] +mod test { + use diplomat_core::{ + ast::{self}, + hir::{self, TypeDef}, + }; + use quote::quote; + use std::borrow::Cow; + use std::collections::HashSet; + + #[test] + fn test_opaque_gen() { + let tokens = quote! 
{ + #[diplomat::bridge] + #[diplomat::attr(auto, namespace = "mylib")] + mod ffi { + + #[diplomat::opaque] + struct OpaqueStruct; + + impl OpaqueStruct { + pub fn new() -> Box { + Box::new(OpaqueStruct{}) + } + + pub fn do_thing() -> bool { + return true; + } + } + } + }; + let item = syn::parse2::(tokens).expect("failed to parse item "); + + let mut attr_validator = hir::BasicAttributeValidator::new("python"); + attr_validator.support = crate::python::attr_support(); + + let tcx = match hir::TypeContext::from_syn(&item, attr_validator) { + Ok(context) => context, + Err(e) => { + for (_cx, err) in e { + eprintln!("Lowering error: {}", err); + } + panic!("Failed to create context") + } + }; + + let (type_id, opaque_def) = match tcx + .all_types() + .next() + .expect("Failed to generate first opaque def") + { + (type_id, TypeDef::Opaque(opaque_def)) => (type_id, opaque_def), + _ => panic!("Failed to find opaque type from AST"), + }; + + let formatter = crate::python::PyFormatter::new(&tcx); + let errors = crate::ErrorStore::default(); + let mut binding = crate::python::Binding::new(); + binding.module_name = std::borrow::Cow::Borrowed("pymod"); + + let decl_header_path = formatter.fmt_decl_header_path(type_id); + let impl_file_path = formatter.fmt_impl_file_path(type_id); + + let mut context = crate::python::TyGenContext { + formatter: &formatter, + errors: &errors, + c: crate::c::TyGenContext { + tcx: &tcx, + formatter: &formatter.c, + errors: &errors, + is_for_cpp: false, + id: type_id.into(), + decl_header_path: decl_header_path.clone().into(), + impl_header_path: impl_file_path.clone().into(), + }, + binding: &mut binding, + generating_struct_fields: false, + submodules: HashSet::>::new(), + }; + + context.gen_opaque_def(opaque_def, type_id); + let generated = binding.to_string(); + insta::assert_snapshot!(generated) + } + + #[test] + fn test_enum_gen() { + let tokens = quote! 
{ + #[diplomat::bridge] + #[diplomat::attr(auto, namespace = "mylib")] + mod ffi { + + #[diplomat::enum_convert(my_thingy::SpeedSetting)] + pub enum SpeedSetting { + Fast, Medium, Slow + } + } + }; + let item = syn::parse2::(tokens).expect("failed to parse item "); + + let mut attr_validator = hir::BasicAttributeValidator::new("python"); + attr_validator.support = crate::python::attr_support(); + + let tcx = match hir::TypeContext::from_syn(&item, attr_validator) { + Ok(context) => context, + Err(e) => { + for (_cx, err) in e { + eprintln!("Lowering error: {}", err); + } + panic!("Failed to create context") + } + }; + + let (type_id, enum_def) = match tcx + .all_types() + .next() + .expect("Failed to generate first opaque def") + { + (type_id, TypeDef::Enum(enum_def)) => (type_id, enum_def), + _ => panic!("Failed to find opaque type from AST"), + }; + + let formatter = crate::python::PyFormatter::new(&tcx); + let errors = crate::ErrorStore::default(); + let mut binding = crate::python::Binding::new(); + binding.module_name = std::borrow::Cow::Borrowed("pymod"); + + let decl_header_path = formatter.fmt_decl_header_path(type_id); + let impl_file_path = formatter.fmt_impl_file_path(type_id); + + let mut context = crate::python::TyGenContext { + formatter: &formatter, + errors: &errors, + c: crate::c::TyGenContext { + tcx: &tcx, + formatter: &formatter.c, + errors: &errors, + is_for_cpp: false, + id: type_id.into(), + decl_header_path: decl_header_path.clone().into(), + impl_header_path: impl_file_path.clone().into(), + }, + binding: &mut binding, + generating_struct_fields: false, + submodules: HashSet::>::new(), + }; + + context.gen_enum_def(enum_def, type_id); + let generated = binding.to_string(); + insta::assert_snapshot!(generated) + } + + #[test] + fn test_struct_gen() { + let tokens = quote! 
{ + #[diplomat::bridge] + #[diplomat::attr(auto, namespace = "mylib")] + mod ffi { + pub struct Thingy { + pub a: bool, + pub b: u8, + pub mut c: f64, + } + } + }; + let item = syn::parse2::(tokens).expect("failed to parse item "); + + let mut attr_validator = hir::BasicAttributeValidator::new("python"); + attr_validator.support = crate::python::attr_support(); + + let tcx = match hir::TypeContext::from_syn(&item, attr_validator) { + Ok(context) => context, + Err(e) => { + for (_cx, err) in e { + eprintln!("Lowering error: {}", err); + } + panic!("Failed to create context") + } + }; + + let (type_id, struct_def) = match tcx + .all_types() + .next() + .expect("Failed to generate first opaque def") + { + (type_id, TypeDef::Struct(struct_def)) => (type_id, struct_def), + _ => panic!("Failed to find opaque type from AST"), + }; + + let formatter = crate::python::PyFormatter::new(&tcx); + let errors = crate::ErrorStore::default(); + let mut binding = crate::python::Binding::new(); + binding.module_name = std::borrow::Cow::Borrowed("pymod"); + + let decl_header_path = formatter.fmt_decl_header_path(type_id); + let impl_file_path = formatter.fmt_impl_file_path(type_id); + + let mut context = crate::python::TyGenContext { + formatter: &formatter, + errors: &errors, + c: crate::c::TyGenContext { + tcx: &tcx, + formatter: &formatter.c, + errors: &errors, + is_for_cpp: false, + id: type_id.into(), + decl_header_path: decl_header_path.clone().into(), + impl_header_path: impl_file_path.clone().into(), + }, + binding: &mut binding, + generating_struct_fields: false, + submodules: HashSet::>::new(), + }; + + context.gen_struct_def(struct_def, type_id); + let generated = binding.to_string(); + insta::assert_snapshot!(generated) + } +} diff --git a/tool/src/python/ty.rs b/tool/src/python/ty.rs new file mode 100644 index 000000000..8bc85947f --- /dev/null +++ b/tool/src/python/ty.rs @@ -0,0 +1,401 @@ +use super::binding::Binding; +use super::PyFormatter; +use crate::c::Header as 
C2Header; +use crate::c::TyGenContext as C2TyGenContext; +use crate::ErrorStore; +use askama::Template; +use diplomat_core::hir::{ + self, EnumVariant, Mutability, OpaqueOwner, ReturnType, StructPathLike, SuccessType, + TyPosition, Type, TypeId, +}; +use std::borrow::Borrow; +use std::borrow::Cow; +use std::collections::HashSet; + +/// A type name with a corresponding variable name, such as a struct field or a function parameter. +struct NamedType<'a> { + var_name: Cow<'a, str>, + _type_name: Cow<'a, str>, +} + +/// Everything needed for rendering a method. +struct MethodInfo<'a> { + /// HIR of the method being rendered + method: &'a hir::Method, + /// The C++ return type + _return_ty: Cow<'a, str>, + /// The C++ method name + method_name: Cow<'a, str>, + /// The C method name + _abi_name: String, + /// Qualifiers for the function that come before the declaration (like "static") + pre_qualifiers: Vec>, + /// Qualifiers for the function that come after the declaration (like "const") + _post_qualifiers: Vec>, + /// Type declarations for the C++ parameters + _param_decls: Vec>, + /// Parameter validations, such as string checks + _param_validations: Vec, +} + +/// Context for generating a particular type's impl +/// 'tcx refers to the lifetime of the typecontext +/// 'cx refers to the lifetime of the context itself +pub(super) struct TyGenContext<'cx, 'tcx> { + pub formatter: &'cx PyFormatter<'tcx>, + pub errors: &'cx ErrorStore<'tcx, String>, + pub c2: C2TyGenContext<'cx, 'tcx>, + pub binding: &'cx mut Binding<'tcx>, + pub submodules: &'cx mut HashSet>, + /// Are we currently generating struct fields? 
+ pub generating_struct_fields: bool, +} + +impl<'ccx, 'tcx: 'ccx, 'bind> TyGenContext<'ccx, 'tcx> { + /// Checks for & outputs a list of modules with their parents that still need to be defined for this type + /// + pub fn get_module_defs( + &mut self, + id: TypeId, + _docstring: Option<&str>, + ) -> Vec<(Cow<'tcx, str>, Cow<'tcx, str>)> { + let mut namespaces = self.formatter.fmt_namespaces(id); + let mut modules: Vec<(Cow<'_, str>, Cow<'_, str>)> = Default::default(); + + while let Some(parent) = namespaces.next() { + if let Some(module) = namespaces.next() { + if self.submodules.contains(&module) { + continue; + } + self.submodules.insert(module.clone()); + + modules.push((module, parent)); + } + } + modules + } + + /// Adds an enum definition to the current implementation. + /// + /// The enum is defined in C++ using a `class` with a single private field that is the + /// C enum type. This enables us to add methods to the enum and generally make the enum + /// behave more like an upgraded C++ type. We don't use `enum class` because methods + /// cannot be added to it. 
+ pub fn gen_enum_def(&mut self, ty: &'tcx hir::EnumDef, id: TypeId) { + let type_name = self.formatter.fmt_type_name(id); + let ctype = self.formatter.fmt_c_type_name(id); + + let values = ty.variants.iter().collect::>(); + + #[derive(Template)] + #[template(path = "python/enum_impl.cpp.jinja", escape = "none")] + struct ImplTemplate<'a> { + _ty: &'a hir::EnumDef, + _fmt: &'a PyFormatter<'a>, + type_name: &'a str, + _ctype: &'a str, + values: &'a [&'a EnumVariant], + module: &'a str, + modules: Vec<(Cow<'a, str>, Cow<'a, str>)>, + } + + ImplTemplate { + _ty: ty, + _fmt: self.formatter, + type_name: &type_name, + _ctype: &ctype, + values: values.as_slice(), + module: self.formatter.fmt_module(id).borrow(), + modules: self.get_module_defs(id, None), + } + .render_into(self.binding) + .unwrap(); + } + + pub fn gen_opaque_def(&mut self, ty: &'tcx hir::OpaqueDef, id: TypeId) { + let type_name = self.formatter.fmt_type_name(id); + let type_name_unnamespaced = self.formatter.fmt_type_name_unnamespaced(id); + let ctype = self.formatter.fmt_c_type_name(id); + let _dtor_name = self + .formatter + .namespace_c_method_name(id, ty.dtor_abi_name.as_str()); + + let c_header = self.c2.gen_opaque_def(ty); + + let methods = ty + .methods + .iter() + .flat_map(|method| self.gen_method_info(id, method)) + .collect::>(); + + #[derive(Template)] + #[template(path = "python/opaque_impl.cpp.jinja", escape = "none")] + struct ImplTemplate<'a> { + // ty: &'a hir::OpaqueDef, + fmt: &'a PyFormatter<'a>, + type_name: &'a str, + ctype: &'a str, + methods: &'a [MethodInfo<'a>], + modules: Vec<(Cow<'a, str>, Cow<'a, str>)>, + module: Cow<'a, str>, + type_name_unnamespaced: &'a str, + _c_header: C2Header, + } + + ImplTemplate { + // ty, + fmt: self.formatter, + type_name: &type_name, + ctype: &ctype, + methods: methods.as_slice(), + modules: self.get_module_defs(id, None), + module: self.formatter.fmt_module(id), + type_name_unnamespaced: &type_name_unnamespaced, + _c_header: c_header, + } + 
.render_into(self.binding) + .unwrap(); + } + + pub fn gen_struct_def(&mut self, def: &'tcx hir::StructDef

, id: TypeId) { + let type_name = self.formatter.fmt_type_name(id); + let type_name_unnamespaced = self.formatter.fmt_type_name_unnamespaced(id); + let ctype = self.formatter.fmt_c_type_name(id); + + let c_header = self.c2.gen_struct_def(def); + let _c_impl_header = self.c2.gen_impl(def.into()); + + self.generating_struct_fields = true; + let field_decls = def + .fields + .iter() + .map(|field| self.gen_ty_decl(&field.ty, field.name.as_str())) + .collect::>(); + self.generating_struct_fields = false; + + let methods = def + .methods + .iter() + .flat_map(|method| self.gen_method_info(id, method)) + .collect::>(); + + #[derive(Template)] + #[template(path = "python/struct_impl.cpp.jinja", escape = "none")] + struct ImplTemplate<'a> { + // ty: &'a hir::OpaqueDef, + // fmt: &'a Cpp2Formatter<'a>, + type_name: &'a str, + _ctype: &'a str, + fields: &'a [NamedType<'a>], + methods: &'a [MethodInfo<'a>], + modules: Vec<(Cow<'a, str>, Cow<'a, str>)>, + module: Cow<'a, str>, + type_name_unnamespaced: &'a str, + _c_header: C2Header, + } + + ImplTemplate { + // ty, + // fmt: &self.formatter, + type_name: &type_name, + _ctype: &ctype, + fields: field_decls.as_slice(), + methods: methods.as_slice(), + modules: self.get_module_defs(id, None), + module: self.formatter.fmt_module(id), + type_name_unnamespaced: &type_name_unnamespaced, + _c_header: c_header, + } + .render_into(self.binding) + .unwrap(); + } + + fn gen_method_info( + &mut self, + id: TypeId, + method: &'tcx hir::Method, + ) -> Option> { + if method.attrs.disable { + return None; + } + let _guard = self.errors.set_context_method( + self.c2.tcx.fmt_type_name_diagnostics(id), + method.name.as_str().into(), + ); + let method_name = self.formatter.fmt_method_name(method); + let abi_name = self + .formatter + .namespace_c_method_name(id, method.abi_name.as_str()); + let mut param_decls = Vec::new(); + + let mut returns_utf8_err = false; + + for param in method.params.iter() { + let decls = self.gen_ty_decl(¶m.ty, 
param.name.as_str()); + param_decls.push(decls); + if matches!( + param.ty, + Type::Slice(hir::Slice::Str(_, hir::StringEncoding::Utf8)) + ) { + returns_utf8_err = true; + } + } + + let mut return_ty = self.gen_cpp_return_type_name(&method.output); + + if returns_utf8_err { + return_ty = "diplomat::result".into(); + }; + + let pre_qualifiers = if method.param_self.is_none() { + vec!["static".into()] + } else { + vec![] + }; + + let post_qualifiers = match &method.param_self { + Some(param_self) if param_self.ty.is_immutably_borrowed() => vec!["const".into()], + Some(_) => vec![], + None => vec![], + }; + + Some(MethodInfo { + method, + _return_ty: return_ty, + method_name, + _abi_name: abi_name, + pre_qualifiers, + _post_qualifiers: post_qualifiers, + _param_decls: param_decls, + _param_validations: Default::default(), + }) + } + + /// Generates C++ code for referencing a particular type with a given name. + fn gen_ty_decl<'a, P: TyPosition>(&mut self, ty: &Type

, var_name: &'a str) -> NamedType<'a> + where + 'ccx: 'a, + { + let var_name = self.formatter.fmt_param_name(var_name); + let type_name = self.gen_type_name(ty); + + NamedType { + var_name, + _type_name: type_name, + } + } + + /// Generates Python code for referencing a particular type. + /// + /// This function adds the necessary type imports to the decl and impl files. + fn gen_type_name(&mut self, ty: &Type

) -> Cow<'ccx, str> { + match *ty { + Type::Primitive(prim) => self.formatter.fmt_primitive_as_c(prim), + Type::Opaque(ref op) => { + let op_id = op.tcx_id.into(); + let type_name = self.formatter.fmt_type_name(op_id); + let _type_name_unnamespaced = self.formatter.fmt_type_name_unnamespaced(op_id); + let def = self.c2.tcx.resolve_type(op_id); + + if def.attrs().disable { + self.errors + .push_error(format!("Found usage of disabled type {type_name}")) + } + let mutability = op.owner.mutability().unwrap_or(hir::Mutability::Mutable); + let ret = match (op.owner.is_owned(), op.is_optional()) { + // unique_ptr is nullable + (true, _) => self.formatter.fmt_owned(&type_name), + (false, true) => self.formatter.fmt_optional_borrowed(&type_name, mutability), + (false, false) => self.formatter.fmt_borrowed(&type_name, mutability), + }; + let ret = ret.into_owned().into(); + + self.binding + .includes + .insert(self.formatter.fmt_impl_file_path(op_id).into()); + ret + } + Type::Struct(ref st) => { + let id = st.id(); + let type_name = self.formatter.fmt_type_name(id); + let _type_name_unnamespaced = self.formatter.fmt_type_name_unnamespaced(id); + let def = self.c2.tcx.resolve_type(id); + if def.attrs().disable { + self.errors + .push_error(format!("Found usage of disabled type {type_name}")) + } + + self.binding + .includes + .insert(self.formatter.fmt_impl_file_path(id).into()); + type_name + } + Type::Enum(ref e) => { + let id = e.tcx_id.into(); + let type_name = self.formatter.fmt_type_name(id); + let _type_name_unnamespaced = self.formatter.fmt_type_name_unnamespaced(id); + let def = self.c2.tcx.resolve_type(id); + if def.attrs().disable { + self.errors + .push_error(format!("Found usage of disabled type {type_name}")) + } + + self.binding + .includes + .insert(self.formatter.fmt_impl_file_path(id).into()); + type_name + } + Type::Slice(hir::Slice::Str(_, encoding)) => self.formatter.fmt_borrowed_str(encoding), + Type::Slice(hir::Slice::Primitive(b, p)) => { + let ret = 
self.formatter.fmt_primitive_as_c(p); + let ret = self.formatter.fmt_borrowed_slice( + &ret, + b.map(|b| b.mutability).unwrap_or(hir::Mutability::Mutable), + ); + ret.into_owned().into() + } + Type::Slice(hir::Slice::Strs(encoding)) => format!( + "diplomat::span", + self.formatter.fmt_borrowed_str(encoding) + ) + .into(), + Type::DiplomatOption(ref inner) => { + format!("std::optional<{}>", self.gen_type_name(inner)).into() + } + Type::Callback(..) => "".into(), + _ => unreachable!("unknown AST/HIR variant"), + } + } + + /// Generates the C++ type name of a return type. + fn gen_cpp_return_type_name(&mut self, result_ty: &ReturnType) -> Cow<'ccx, str> { + match *result_ty { + ReturnType::Infallible(SuccessType::Unit) => "void".into(), + ReturnType::Infallible(SuccessType::Write) => self.formatter.fmt_owned_str(), + ReturnType::Infallible(SuccessType::OutType(ref o)) => self.gen_type_name(o), + ReturnType::Fallible(ref ok, ref err) => { + let ok_type_name = match ok { + SuccessType::Write => self.formatter.fmt_owned_str(), + SuccessType::Unit => "std::monostate".into(), + SuccessType::OutType(o) => self.gen_type_name(o), + _ => unreachable!("unknown AST/HIR variant"), + }; + let err_type_name = match err { + Some(o) => self.gen_type_name(o), + None => "std::monostate".into(), + }; + format!("diplomat::result<{ok_type_name}, {err_type_name}>").into() + } + ReturnType::Nullable(ref ty) => { + let type_name = match ty { + SuccessType::Write => self.formatter.fmt_owned_str(), + SuccessType::Unit => "std::monostate".into(), + SuccessType::OutType(o) => self.gen_type_name(o), + _ => unreachable!("unknown AST/HIR variant"), + }; + self.formatter.fmt_optional(&type_name).into() + } + _ => unreachable!("unknown AST/HIR variant"), + } + } +} diff --git a/tool/templates/python/binding.cpp.jinja b/tool/templates/python/binding.cpp.jinja new file mode 100644 index 000000000..9096d0b5c --- /dev/null +++ b/tool/templates/python/binding.cpp.jinja @@ -0,0 +1,16 @@ +#include + +{%- 
for include in includes %} +#include "{{ include }}" +{%- endfor %} +#include + +namespace nb = nanobind; +using namespace nb::literals; + +NB_MODULE({{module_name}}, {{module_name}}) +{ + {%- for line in body.lines() %} + {{ line }} + {%- endfor %} +} diff --git a/tool/templates/python/c_include.h.jinja b/tool/templates/python/c_include.h.jinja new file mode 100644 index 000000000..bc47de785 --- /dev/null +++ b/tool/templates/python/c_include.h.jinja @@ -0,0 +1,9 @@ +{% if let Some(ns) = namespace -%} +namespace {{ns}} { +{% else -%} +namespace diplomat { +{% endif -%} +namespace {{self::CAPI_NAMESPACE}} { + {{ c_header.body|trim|indent(4) }} +} // namespace {{self::CAPI_NAMESPACE}} +} // namespace diff --git a/tool/templates/python/enum_impl.cpp.jinja b/tool/templates/python/enum_impl.cpp.jinja new file mode 100644 index 000000000..122fdf63a --- /dev/null +++ b/tool/templates/python/enum_impl.cpp.jinja @@ -0,0 +1,6 @@ +{% include "module_impl.cpp.jinja" %} +nb::enum_<{{type_name}}::Value>({{module}}_mod, "{{type_name}}") + {%- for v in values %} + .value("{{v.name}}", {{type_name}}::{{v.name}}) + {%- endfor -%}; + \ No newline at end of file diff --git a/tool/templates/python/method_impl.cpp.jinja b/tool/templates/python/method_impl.cpp.jinja new file mode 100644 index 000000000..cd65ffb4a --- /dev/null +++ b/tool/templates/python/method_impl.cpp.jinja @@ -0,0 +1,6 @@ + .def{%- for qualifier in m.pre_qualifiers -%}{%- if qualifier == "static" -%}_static{%- endif -%}{%- endfor -%} + ("{{m.method_name}}", &{{- type_name }}::{{ m.method_name -}} + {%- for param in m.method.params -%} + , "{{param.name}}"_a + {%- endfor -%} + ) \ No newline at end of file diff --git a/tool/templates/python/module_impl.cpp.jinja b/tool/templates/python/module_impl.cpp.jinja new file mode 100644 index 000000000..1b650739e --- /dev/null +++ b/tool/templates/python/module_impl.cpp.jinja @@ -0,0 +1,3 @@ +{%- for (m, p) in modules %} +nb::module_ {{m}}_mod({{p}}); +{%- endfor %} \ No 
newline at end of file diff --git a/tool/templates/python/opaque_impl.cpp.jinja b/tool/templates/python/opaque_impl.cpp.jinja new file mode 100644 index 000000000..b653c2099 --- /dev/null +++ b/tool/templates/python/opaque_impl.cpp.jinja @@ -0,0 +1,12 @@ +{% let const_ptr = fmt.fmt_c_ptr(type_name, Mutability::Immutable) -%} +{% let mut_ptr = fmt.fmt_c_ptr(type_name, Mutability::Mutable) -%} +{% let const_cptr = fmt.fmt_c_ptr(ctype, Mutability::Immutable) -%} +{% let mut_cptr = fmt.fmt_c_ptr(ctype, Mutability::Mutable) -%} +{% let const_ref = fmt.fmt_borrowed(type_name, Mutability::Immutable) -%} +{% let move_ref = fmt.fmt_move_ref(type_name) -%} + +{% include "module_impl.cpp.jinja" %} +nb::class_<{{type_name}}>({{ module }}_mod, "{{type_name_unnamespaced}}") +{%- for m in methods %} +{% include "method_impl.cpp.jinja" %} +{%- endfor %}; diff --git a/tool/templates/python/struct_impl.cpp.jinja b/tool/templates/python/struct_impl.cpp.jinja new file mode 100644 index 000000000..d0886610c --- /dev/null +++ b/tool/templates/python/struct_impl.cpp.jinja @@ -0,0 +1,8 @@ +{% include "module_impl.cpp.jinja" %} +nb::class_<{{type_name}}>({{module}}_mod, "{{type_name_unnamespaced}}") +{%- for f in fields %} + .def_ro("{{f.var_name}}", &{{type_name}}::{{f.var_name}}) +{%- endfor %} +{%- for m in methods %} +{% include "method_impl.cpp.jinja" %} +{%- endfor %}; \ No newline at end of file