Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pytorch, OpenCV] Implicit ArrayRef, Updates for JavaCPP 1.5.10 #1455

Merged
merged 14 commits into from
Jan 14, 2024
Next Next commit
Have methods taking ArrayRef accept Vector
  • Loading branch information
HGuillemet committed Jan 9, 2024
commit cf2136f0314ade4166d7849e2d044795f48966c7
43 changes: 24 additions & 19 deletions pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java
Original file line number Diff line number Diff line change
Expand Up @@ -657,22 +657,25 @@ public void map(InfoMap infoMap) {


//// c10::ArrayRef
/* Transparent cast from variadic java args to ArrayRef is only possible for non-boolean primitives (see mapArrayRef).
* For Pointer subclasses for which a std::vector has been instantiated, we rely on ArrayRef converting constructor from std::vector and add the vector as with otherPointerTypes()
*/
for (ArrayInfo t : new ArrayInfo[]{
new ArrayInfo("Argument").elementTypes("c10::Argument"),
new ArrayInfo("ArgumentDef").elementTypes("c10::detail::infer_schema::ArgumentDef"),
new ArrayInfo("BFloat16") /*.itPointerType("ShortPointer") */.elementTypes("decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::BFloat16>::t)"),
new ArrayInfo("Block").elementTypes("torch::jit::Block*").itPointerType("PointerPointer<Block>"),
new ArrayInfo("Bool").itPointerType("BoolPointer").elementTypes("bool", "decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::Bool>::t)").elementValueType("boolean"),
new ArrayInfo("Byte").itPointerType("BytePointer").elementTypes("jbyte", "int8_t", "uint8_t").elementValueType("byte"),
new ArrayInfo("Dimname").otherCppNames("at::DimnameList").elementTypes("at::Dimname"),
new ArrayInfo("Dimname").otherCppNames("at::DimnameList").elementTypes("at::Dimname").otherPointerTypes("DimnameVector"),
new ArrayInfo("Double").itPointerType("DoublePointer").elementTypes("double"),
new ArrayInfo("DoubleComplex") /*.itPointertype("DoublePointer") */.elementTypes("c10::complex<double>"),
new ArrayInfo("EnumNameValue").elementTypes("c10::EnumNameValue"),
new ArrayInfo("Float").itPointerType("FloatPointer").elementTypes("float").elementValueType("float"),
new ArrayInfo("FloatComplex") /*.itPointerType("FloatPointer") */.elementTypes("c10::complex<float>"),
new ArrayInfo("FuturePtr").elementTypes("c10::intrusive_ptr<c10::ivalue::Future>"),
new ArrayInfo("Half") /*.itPointerType("ShortPointer") */.elementTypes("decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::Half>::t)"),
new ArrayInfo("IValue").elementTypes("c10::IValue", "const at::IValue"),
new ArrayInfo("IValue").elementTypes("c10::IValue", "const at::IValue").otherPointerTypes("IValueVector"),
new ArrayInfo("Int")
.itPointerType("IntPointer")
.elementTypes("jint", "int", "int32_t", "uint32_t")
Expand All @@ -683,23 +686,23 @@ public void map(InfoMap infoMap) {
.itPointerType("LongPointer")
.elementTypes("int64_t", "jlong") // Order is important, since ArrayRef<long> and ArrayRef<long long> are incompatible, even though long == long long. And jlong is long long.
.elementValueType("long"),
new ArrayInfo("LongOptional").elementTypes("c10::optional<int64_t>"),
new ArrayInfo("LongOptional").elementTypes("c10::optional<int64_t>").otherPointerTypes("LongOptionalVector"),
new ArrayInfo("NamedValue").elementTypes("torch::jit::NamedValue"),
new ArrayInfo("Scalar").elementTypes("at::Scalar"),
new ArrayInfo("ScalarType").itPointerType("@Cast(\"c10::ScalarType*\") BytePointer").elementTypes("c10::ScalarType", "at::ScalarType"),
new ArrayInfo("ScalarType").itPointerType("@Cast(\"c10::ScalarType*\") BytePointer").elementTypes("c10::ScalarType", "at::ScalarType").otherPointerTypes("ScalarTypeVector"),
new ArrayInfo("Short").itPointerType("ShortPointer").elementTypes("jshort", "int16_t", "uint16_t").elementValueType("short"),
new ArrayInfo("SizeT").itPointerType("SizeTPointer").elementTypes("size_t").elementValueType("long"),
new ArrayInfo("Stride").elementTypes("c10::Stride"),
new ArrayInfo("String").itPointerType("PointerPointer<BytePointer>" /*"@Cast({\"\", \"std::string*\"}) @StdString BytePointer"*/).elementTypes("std::string"),
new ArrayInfo("Stride").elementTypes("c10::Stride").otherPointerTypes("StrideVector"),
new ArrayInfo("String").itPointerType("PointerPointer<BytePointer>" /*"@Cast({\"\", \"std::string*\"}) @StdString BytePointer"*/).elementTypes("std::string").otherPointerTypes("StringVector"),
new ArrayInfo("SymInt").otherCppNames("c10::SymIntArrayRef").elementTypes("c10::SymInt"),
new ArrayInfo("SymNode").elementTypes("c10::SymNode", "c10::intrusive_ptr<c10::SymNodeImpl>"),
new ArrayInfo("Symbol").elementTypes("c10::Symbol"),
new ArrayInfo("Tensor").otherCppNames("torch::TensorList", "at::TensorList", "at::ITensorListRef").elementTypes("torch::Tensor", "at::Tensor"), // Warning: not a TensorList (List<Tensor>)
new ArrayInfo("Symbol").elementTypes("c10::Symbol").otherPointerTypes("SymbolVector"),
new ArrayInfo("Tensor").otherCppNames("torch::TensorList", "at::TensorList", "at::ITensorListRef").elementTypes("torch::Tensor", "at::Tensor").otherPointerTypes("TensorVector"), // Warning: not a TensorList (List<Tensor>)
new ArrayInfo("TensorArg").elementTypes("torch::TensorArg", "at::TensorArg"),
new ArrayInfo("TensorIndex").elementTypes("at::indexing::TensorIndex"),
new ArrayInfo("TensorOptional").elementTypes("c10::optional<at::Tensor>", "c10::optional<torch::Tensor>", "c10::optional<torch::autograd::Variable>"),
new ArrayInfo("Type").itPointerType("Type.TypePtr").elementTypes("c10::TypePtr", "c10::Type::TypePtr"),
new ArrayInfo("Value").elementTypes("torch::jit::Value*")
new ArrayInfo("TensorIndex").elementTypes("at::indexing::TensorIndex").otherPointerTypes("TensorIndexVector"),
new ArrayInfo("TensorOptional").elementTypes("c10::optional<at::Tensor>", "c10::optional<torch::Tensor>", "c10::optional<torch::autograd::Variable>").otherPointerTypes("TensorOptionalVector"),
new ArrayInfo("Type").itPointerType("Type.TypePtr").elementTypes("c10::TypePtr", "c10::Type::TypePtr").otherPointerTypes("TypeVector"),
new ArrayInfo("Value").elementTypes("torch::jit::Value*").otherPointerTypes("ValueVector")

}) {
t.mapArrayRef(infoMap);
Expand Down Expand Up @@ -762,7 +765,7 @@ public void map(InfoMap infoMap) {
}
// swap is a friend templated function. Parser fails to perform template substitution in this case.
infoMap.put(new Info("c10::impl::ListElementReference::swap<T,Iterator>").skip());
// friendly global setting lost
// friendly global setting lost + full qualification not resolved by parser
infoMap.put(new Info("impl::ptr_to_first_element(const c10::List<c10::IValue>&)").javaNames("ptr_to_first_element").annotations("@Name(\"c10::impl::ptr_to_first_element\")").friendly());


Expand Down Expand Up @@ -2567,15 +2570,17 @@ void mapArrayRef(InfoMap infoMap) {
cppNamesRIterator[n++] = cn + "::const_reverse_iterator";
}

// Use converting constructor from std::vector when it works to allow passing java array literals
boolean noVariadicPointerType =
elementValueType.contains(" ") // No @ByVal
|| elementValueType.equals("boolean"); // ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.
// Use converting constructor from std::vector when it works to allow passing java array literals.
// Generator doesn't support passing arrays of Pointers as argument, so elementType must be primitive
// and not boolean, since ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.
boolean variadicPointerType = elementValueType.equals("byte") || elementValueType.equals("short") ||
elementValueType.equals("int") || elementValueType.equals("long") ||
elementValueType.equals("float") || elementValueType.equals("double");

String[] pt = new String[otherPointerTypes.length + (noVariadicPointerType ? 1 : 2)];
String[] pt = new String[otherPointerTypes.length + (variadicPointerType ? 2 : 1)];
pt[0] = baseJavaName + "ArrayRef";
System.arraycopy(otherPointerTypes, 0, pt, 1, otherPointerTypes.length);
if (!noVariadicPointerType)
if (variadicPointerType)
pt[otherPointerTypes.length + 1] = "@Cast({\"" + elementTypes[0] + "*\", \"" + cppNames[0] + "\", \"std::vector<" + elementTypes[0] + ">&\"}) @StdVector(\"" + elementTypes[0] + "\") " + elementValueType + "...";
Info info = new Info(cppNames).pointerTypes(pt);
if (baseJavaName.contains("@Cast")) info.cast();
Expand Down