Skip to content

Commit

Permalink
Support Union in TorchScript (pytorch#64234)
Browse files Browse the repository at this point in the history
Summary:
This PR is created to replace pytorch#53180 PR stack, which has all the review discussions. Reason for needing a replacement is due to a messy Sandcastle issue.

Pull Request resolved: pytorch#64234

Reviewed By: gmagogsfm

Differential Revision: D30656444

Pulled By: ansley

fbshipit-source-id: 77536c8bcc88162e2c72636026ca3c16891d669a
  • Loading branch information
Ansley Ussery authored and facebook-github-bot committed Sep 3, 2021
1 parent 91b926f commit 6831d8e
Show file tree
Hide file tree
Showing 50 changed files with 2,132 additions and 462 deletions.
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,12 @@ is `./build/bin/FILENAME --gtest_filter=TESTSUITE.TESTNAME`, where
`TESTNAME` is the name of the test you'd like to run and `TESTSUITE` is
the suite that test is defined in.

For example, if you wanted to run the test ` MayContainAlias`, which
For example, if you wanted to run the test `MayContainAlias`, which
is part of the test suite `ContainerAliasingTest` in the file
`test/cpp/jit/test_alias_analysis.cpp`, the command would be:

```bash
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.UnionAliasing
./build/bin/test_jit --gtest_filter=ContainerAliasingTest.MayContainAlias
```


Expand Down
172 changes: 119 additions & 53 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ struct FunctionSchema;
struct NamedType;
using OptNameList = c10::optional<std::vector<std::string>>;

void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);

struct AnyType;
using AnyTypePtr = std::shared_ptr<AnyType>;
// Any is the top of the type hierarchy, all other types are subtypes
Expand Down Expand Up @@ -94,25 +97,84 @@ struct SingleElementType : public Type {
TypePtr elem;
};

struct UnionType;
using UnionTypePtr = std::shared_ptr<UnionType>;
struct TORCH_API UnionType : public Type {
friend struct Type;

static const TypeKind Kind = TypeKind::UnionType;

bool isSubtypeOfExt(const TypePtr& rhs_, std::ostream* why_not) const override;

std::string str() const override;

static UnionTypePtr create(std::vector<TypePtr> reference);

bool operator==(const Type& rhs) const override;

at::ArrayRef<TypePtr> containedTypes() const override {
return types_;
}

// For testing purposes only
at::ArrayRef<TypePtr> getTypes() const {
return types_;
}

TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
return create(contained_types);
}

bool canHoldType(TypePtr type) const;

bool hasFreeVariables() const override {
return has_free_variables_;
}

c10::optional<TypePtr> toOptional() const;

c10::optional<TypePtr> subtractTypeSet(std::vector<TypePtr>& to_subtract) const;

protected:
explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool has_free_variables_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<TypePtr> types_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool can_hold_none_;

};

struct OptionalType;
using OptionalTypePtr = std::shared_ptr<OptionalType>;
// This type represents an optional type, for each element type.
// Optional[T] can accept both T and None(nullopt in C++)
// This type represents an optional type. There is one `Optional` for
// each element type. `Optional[T]` can accept both `T` and
// `None`(`c10::nullopt` in C++)
// Subtype hierarchy for Optional:
// 1. Optional[T] <: Optional[R] iff T <: R
// 2. T <: Optional[R] if T <: R
// 3. None <: Optional[T] for all T
struct TORCH_API OptionalType
: public SingleElementType<TypeKind::OptionalType, OptionalType> {
static OptionalTypePtr create(TypePtr element) {
TORCH_INTERNAL_ASSERT(element, "OptionalType requires valid TypePtr");
// Optional is a union of [None, T], so Optional[[Optional[T]]] ->
// Optional[T]
if (auto opt_ptr = element->cast<OptionalType>()) {
return opt_ptr;
}
return OptionalTypePtr(
new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
// - Optional[T] <: Optional[R] iff T <: R
// - T <: Optional[R] if T <: R
// - None <: Optional[T] for all T
// - Optional[T] == Union[T, None] for all T
struct TORCH_API OptionalType : public UnionType {
static OptionalTypePtr create(TypePtr contained) {
return OptionalTypePtr(new OptionalType(std::move(contained)));
}

static const TypeKind Kind = TypeKind::OptionalType;

friend struct Type;

bool operator==(const Type& rhs) const override;

TypePtr getElementType() const {
return contained_;
}

at::ArrayRef<TypePtr> containedTypes() const override {
return contained_;
}

std::string str() const override {
Expand All @@ -127,20 +189,15 @@ struct TORCH_API OptionalType
return create(contained_types[0]);
}

bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
if (Type::isSubtypeOfExt(rhs, why_not)) {
return true;
}
if (auto rhs_ = rhs->cast<OptionalType>()) {
return getElementType()->isSubtypeOfExt(rhs_->getElementType(), why_not);
}
return false;
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;

// common cast Optional[Tensor] for undefined tensor type
static OptionalTypePtr ofTensor();

private:
OptionalType(TypePtr elem) : SingleElementType(elem) {}
explicit OptionalType(TypePtr contained);

TypePtr contained_;

std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
std::stringstream ss;
Expand Down Expand Up @@ -908,7 +965,6 @@ struct TORCH_API RRefType
}
};


struct NamedType;
using NamedTypePtr = std::shared_ptr<NamedType>;
using ConstNamedTypePtr = std::shared_ptr<const NamedType>;
Expand Down Expand Up @@ -1112,7 +1168,6 @@ struct TORCH_API EnumType : public NamedType {
std::weak_ptr<::torch::jit::CompilationUnit> cu_;
};


// the common supertype of all Enums, only used in operator registraion.
// EnumType <: AnyEnumType for all Enums
struct AnyEnumType;
Expand All @@ -1132,7 +1187,6 @@ struct TORCH_API AnyEnumType : public Type {
: Type(TypeKind::AnyEnumType) {}
};


struct NumberType;
using NumberTypePtr = std::shared_ptr<NumberType>;
// This type represents a Python number
Expand All @@ -1141,9 +1195,10 @@ using NumberTypePtr = std::shared_ptr<NumberType>;
// FloatType <: NumberType
// ComplexType <:NumberType
struct TORCH_API NumberType : public Type {
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
bool operator==(const Type& rhs) const override;

bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;

std::string str() const override {
return "Scalar"; // match what PythonArgParser says for clarity
}
Expand Down Expand Up @@ -1172,7 +1227,8 @@ struct TORCH_API FloatType : public NumberType {
return "float";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::FloatType;
// global singleton
Expand All @@ -1196,7 +1252,8 @@ struct TORCH_API ComplexType : public NumberType {
return "complex";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::ComplexType;
// global singleton
Expand All @@ -1220,7 +1277,8 @@ struct TORCH_API IntType : public NumberType {
return "int";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override {
return rhs->kind() == TypeKind::NumberType || NumberType::isSubtypeOfExt(rhs, why_not);
// NOLINTNEXTLINE(bugprone-parent-virtual-call)
return rhs->kind() == TypeKind::NumberType || Type::isSubtypeOfExt(rhs, why_not);
}
static const TypeKind Kind = TypeKind::IntType;
// global singleton
Expand Down Expand Up @@ -1334,12 +1392,8 @@ struct TORCH_API NoneType : public Type {
std::string str() const override {
return "NoneType";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override {
if (rhs->kind() == OptionalType::Kind) {
return true;
}
return Type::isSubtypeOfExt(rhs, why_not);
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override;

static const TypeKind Kind = TypeKind::NoneType;
// global singleton
static NoneTypePtr get();
Expand Down Expand Up @@ -1524,8 +1578,15 @@ TORCH_API std::ostream& operator<<(std::ostream& os, const Stride& s);
// what is the type, ignoring extra size/shape information?
// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)

// xxx: be careful with calls because this can be very slow. If calling this on a graph
// use `EraseShapeInformation` in shape_analysis.h
// `unshapedType` is used to remove Tensor subtypes. We treat all Tensor
// subtypes as simply "Tensor"; we also create a new version of any
// container types in which internal Tensors have undergone the same
// operation. This is used for type comparisons between two Tensor types
// (`unshapedType` means that we don't falsely return `false` for e.g.
// Tensors of different dimensions). It's also used in the alias
// analysis pass.
// Be careful with calls because this can be very slow. If calling this
// on a graph, use `EraseShapeInformation` in shape_analysis.h
inline TypePtr unshapedType(const TypePtr& type) {
if (type->isSubtypeOf(TensorType::get())) {
return TensorType::get();
Expand Down Expand Up @@ -1569,27 +1630,32 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
return *result;
}

// Attempt to find the correct supertype of t1 and t2. If none is found then
// nullopt will be returned if default_to_any is false, and Any will be returned
// if it is true. If t1 == t2, or t1 is a type refinement of t2,
// then t2 will be returned (and vice versa).
// Attempt to find the correct supertype of the two types `t1` and `t2`.
// If no supertype is found, then nullopt will be returned if
// `default_to_union` is false, and `Union[t1, t2]` will be returned
// if it is true. If `t1 == t2`, or `t1` is a type refinement of `t2`,
// then `t2` will be returned (and vice versa).
//
// Two different tensortypes will return dynamic.
// Currently we chose not to support returning a NumberType for a float & int
// input because of a lack of operator support for NumberType.
//
// Currently we chose not to support returning a NumberType for
// two types from the set of {FloatType, IntType, ComplexType}, because
// there is a lack of operator support for NumberType.
//
// If `type_hint` is an `InterfaceType`, then we can use that as a
// potential supertype for `ClassType`s in the list. Otherwise, we have
// no way to find and use some common interface type
TORCH_API c10::optional<TypePtr> unifyTypes(
const TypePtr& t1,
const TypePtr& t2,
bool default_to_any = false,
TypePtr type_hint=nullptr);
bool default_to_union = false,
TypePtr type_hint = nullptr);

TORCH_API c10::optional<TypePtr> unifyTypeList(
at::ArrayRef<TypePtr> elements,
std::ostream& why_not,
bool default_to_any=false,
TypePtr type_hint=nullptr);
bool default_to_union = false,
TypePtr type_hint = nullptr);

namespace detail {
template <typename T>
Expand Down
7 changes: 4 additions & 3 deletions aten/src/ATen/core/jit_type_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace c10 {
_(DictType) \
_(NumberType) \
_(FloatType) \
_(ComplexType) \
_(ComplexType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
Expand All @@ -44,7 +44,8 @@ namespace c10 {
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)
_(AnyClassType) \
_(UnionType)

enum class TypeKind {
#define DEFINE_TYPE(T) T,
Expand Down Expand Up @@ -203,7 +204,7 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types) {
auto current_contained = containedTypes();
AT_ASSERT(current_contained.size() == contained_types.size());
TORCH_INTERNAL_ASSERT(current_contained.size() == contained_types.size());
if (current_contained.equals(contained_types)) {
return shared_from_this();
}
Expand Down
Loading

0 comments on commit 6831d8e

Please sign in to comment.