Skip to content

Commit

Permalink
Use c10::variant-based enums for Nonlinearity and FanMode
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#27933

Test Plan: Imported from OSS

Differential Revision: D18009044

Pulled By: yf225

fbshipit-source-id: e88229ee30badf7a699f62af61d1e88debc0dc7d
  • Loading branch information
Will Feng authored and facebook-github-bot committed Oct 19, 2019
1 parent a1e14a6 commit eb4bb00
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 17 deletions.
40 changes: 40 additions & 0 deletions test/cpp/api/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <functional>
#include <vector>

using namespace torch::test;

void check_exact_values(
const std::vector<torch::Tensor>& parameters,
const std::vector<std::vector<torch::Tensor>>& expected_parameters) {
Expand Down Expand Up @@ -127,4 +129,42 @@ TEST(InitTest, CalculateGainWithLeakyRelu) {
// Orthogonal initialization must accept the 4-D weight tensor of a conv
// layer (it operates on the flattened 2-D view internally), so simply
// running it on a Conv2d weight without throwing is the whole test.
TEST(InitTest, CanInitializeCnnWithOrthogonal) {
  const auto conv_options = torch::nn::Conv2dOptions(3, 2, 3).stride(2);
  torch::nn::Conv2d conv_layer(conv_options);
  torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);
}

// Invokes `torch::nn::init::func_name` with the deprecated
// `Nonlinearity::enum_name` value while capturing stderr, then asserts that
// exactly one deprecation warning mentioning the replacement
// `enum_torch_kname` string was printed.
#define NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
{ \
std::stringstream buffer; \
CerrRedirect cerr_redirect(buffer.rdbuf()); \
std::cerr << torch::nn::init::func_name(torch::nn::init::Nonlinearity::enum_name) << std::endl; \
ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
}

// Invokes `torch::nn::init::func_name` (a kaiming_* initializer taking a
// tensor, an `a` slope, and a fan mode) with the deprecated
// `FanMode::enum_name` value while capturing stderr, then asserts that
// exactly one deprecation warning mentioning the replacement
// `enum_torch_kname` string was printed.
#define FANMODE_ENUM_LEGACY_WARNING_CHECK(func_name, enum_name, enum_torch_kname) \
{ \
std::stringstream buffer; \
CerrRedirect cerr_redirect(buffer.rdbuf()); \
std::cerr << torch::nn::init::func_name(torch::randn({4, 5}), 0, torch::nn::init::FanMode::enum_name) << std::endl; \
ASSERT_EQ(count_substr_occurrences(buffer.str(), enum_torch_kname), 1); \
}

// Every legacy `Nonlinearity` enum value passed to calculate_gain must emit
// a deprecation warning pointing at the corresponding `torch::k*` value.
TEST(InitTest, NonlinearityLegacyEnum) {
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Linear, "torch::kLinear")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv1D, "torch::kConv1D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv2D, "torch::kConv2D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Conv3D, "torch::kConv3D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose1D, "torch::kConvTranspose1D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose2D, "torch::kConvTranspose2D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ConvTranspose3D, "torch::kConvTranspose3D")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Sigmoid, "torch::kSigmoid")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, Tanh, "torch::kTanh")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, ReLU, "torch::kReLU")
NONLINEARITY_ENUM_LEGACY_WARNING_CHECK(calculate_gain, LeakyReLU, "torch::kLeakyReLU")
}

// Every legacy `FanMode` enum value passed to the kaiming_* initializers
// must emit a deprecation warning pointing at the corresponding
// `torch::kFanIn` / `torch::kFanOut` value.
TEST(InitTest, FanModeLegacyEnum) {
FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanIn, "torch::kFanIn")
FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_normal_, FanOut, "torch::kFanOut")

FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanIn, "torch::kFanIn")
FANMODE_ENUM_LEGACY_WARNING_CHECK(kaiming_uniform_, FanOut, "torch::kFanOut")
}
41 changes: 35 additions & 6 deletions torch/csrc/api/include/torch/nn/init.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#pragma once

#include <c10/util/variant.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace init {

// This enum class is deprecated and will be removed in 1.5
enum class Nonlinearity {
Linear,
Conv1D,
Expand All @@ -21,16 +24,42 @@ enum class Nonlinearity {
LeakyReLU
};

// This enum class is deprecated and will be removed in 1.5
// (callers should pass `torch::kFanIn` / `torch::kFanOut` instead; see
// `FanModeType` below, which still accepts this enum for compatibility).
enum class FanMode { FanIn, FanOut };

// Variant-based nonlinearity selector accepted by `calculate_gain` and the
// kaiming_* initializers: any of the `torch::k*` nonlinearity enum values,
// plus the legacy `Nonlinearity` enum class for backward compatibility.
using NonlinearityType = c10::variant<
enumtype::kLinear,
enumtype::kConv1D,
enumtype::kConv2D,
enumtype::kConv3D,
enumtype::kConvTranspose1D,
enumtype::kConvTranspose2D,
enumtype::kConvTranspose3D,
enumtype::kSigmoid,
enumtype::kTanh,
enumtype::kReLU,
enumtype::kLeakyReLU,

// Support for this enum class is deprecated and will be removed in 1.5.
Nonlinearity
>;

// Variant-based fan-mode selector accepted by the kaiming_* initializers:
// `torch::kFanIn` / `torch::kFanOut`, plus the legacy `FanMode` enum class
// for backward compatibility.
using FanModeType = c10::variant<
enumtype::kFanIn,
enumtype::kFanOut,

// Support for this enum class is deprecated and will be removed in 1.5.
FanMode
>;

} // namespace init
} // nn

namespace nn {
namespace init {

/// Return the recommended gain value for the given nonlinearity function.
TORCH_API double calculate_gain(Nonlinearity nonlinearity, double param = 0.01);
TORCH_API double calculate_gain(NonlinearityType nonlinearity, double param = 0.01);

/// Fills the given `tensor` with the provided `value` in-place, and returns it.
/// No gradient will be recorded for this operation.
Expand Down Expand Up @@ -83,8 +112,8 @@ TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);
TORCH_API Tensor kaiming_normal_(
Tensor tensor,
double a = 0,
FanMode mode = torch::nn::init::FanMode::FanIn,
Nonlinearity nonlinearity = torch::nn::init::Nonlinearity::LeakyReLU);
FanModeType mode = torch::kFanIn,
NonlinearityType nonlinearity = torch::kLeakyReLU);

/// Fills the input `Tensor` with values according to the method
/// described in "Delving deep into rectifiers: Surpassing human-level
Expand All @@ -94,8 +123,8 @@ TORCH_API Tensor kaiming_normal_(
TORCH_API Tensor kaiming_uniform_(
Tensor tensor,
double a = 0,
FanMode mode = torch::nn::init::FanMode::FanIn,
Nonlinearity nonlinearity = torch::nn::init::Nonlinearity::LeakyReLU);
FanModeType mode = torch::kFanIn,
NonlinearityType nonlinearity = torch::kLeakyReLU);

/// Fills the input `Tensor` with values according to the method
/// described in "Understanding the difficulty of training deep feedforward
Expand All @@ -116,4 +145,4 @@ TORCH_API Tensor zeros_(Tensor tensor);

} // namespace init
} // namespace nn
} // namespace torch
} // namespace torch
82 changes: 71 additions & 11 deletions torch/csrc/api/src/nn/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,72 @@ struct Fan {
int64_t out;
};

// Expands to one `case` of a switch over the deprecated `Nonlinearity`
// enum: warns that the value is deprecated and returns the matching
// variant-based `torch::k##name` value.
#define COMPUTE_NONLINEARITY_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
case Nonlinearity::name: \
TORCH_WARN( \
"The enum value `torch::nn::init::Nonlinearity::", #name, "` is deprecated and will be removed in 1.5. ", \
"Please use `torch::k", #name, "` instead."); \
return torch::k##name;

// Expands to one `case` of a switch over the deprecated `FanMode` enum:
// warns that the value is deprecated and returns the matching variant-based
// `torch::k##name` value.
#define COMPUTE_FANMODE_ENUM(name) /* NOLINT(cppcoreguidelines-macro-usage) */ \
case FanMode::name: \
TORCH_WARN( \
"The enum value `torch::nn::init::FanMode::", #name, "` is deprecated and will be removed in 1.5. ", \
"Please use `torch::k", #name, "` instead."); \
return torch::k##name;

// Converts a deprecated `Nonlinearity` enum value into the equivalent
// variant-based `NonlinearityType`, emitting a deprecation warning per call
// (via COMPUTE_NONLINEARITY_ENUM). The `default` branch is unreachable for
// the current enum and exists only to fail loudly if someone extends the
// deprecated enum instead of `torch/enum.h`.
// NOTE: the warning text is asserted verbatim by tests — do not reword.
NonlinearityType _compute_nonlinearity_type(Nonlinearity nonlinearity) {
switch (nonlinearity) {
COMPUTE_NONLINEARITY_ENUM(Linear)
COMPUTE_NONLINEARITY_ENUM(Conv1D)
COMPUTE_NONLINEARITY_ENUM(Conv2D)
COMPUTE_NONLINEARITY_ENUM(Conv3D)
COMPUTE_NONLINEARITY_ENUM(ConvTranspose1D)
COMPUTE_NONLINEARITY_ENUM(ConvTranspose2D)
COMPUTE_NONLINEARITY_ENUM(ConvTranspose3D)
COMPUTE_NONLINEARITY_ENUM(Sigmoid)
COMPUTE_NONLINEARITY_ENUM(Tanh)
COMPUTE_NONLINEARITY_ENUM(ReLU)
COMPUTE_NONLINEARITY_ENUM(LeakyReLU)
default:
TORCH_INTERNAL_ASSERT(
false,
"The enum class `torch::nn::init::Nonlinearity` is deprecated, ",
"please don't add any new enum to it. ",
"Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
"and use `torch::kEnumName` to reference it.")
}
}

// Converts a deprecated `FanMode` enum value into the equivalent
// variant-based `FanModeType`, emitting a deprecation warning per call (via
// COMPUTE_FANMODE_ENUM). The `default` branch is unreachable for the
// current enum and exists only to fail loudly if someone extends the
// deprecated enum instead of `torch/enum.h`.
FanModeType _compute_fanmode_type(FanMode fanmode) {
  switch (fanmode) {
    // No trailing `;` — the macro expands to a full `case ...: return ...;`
    // (matches the style of _compute_nonlinearity_type above).
    COMPUTE_FANMODE_ENUM(FanIn)
    COMPUTE_FANMODE_ENUM(FanOut)
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          // Fixed copy-paste: this function handles `FanMode`, not
          // `Nonlinearity`.
          "The enum class `torch::nn::init::FanMode` is deprecated, ",
          "please don't add any new enum to it. ",
          "Instead, add the new enum to `torch/csrc/api/include/torch/enum.h` ",
          "and use `torch::kEnumName` to reference it.")
  }
}

double calculate_kaiming_std(
Tensor tensor,
double a,
FanMode mode,
Nonlinearity nonlinearity) {
FanModeType mode,
NonlinearityType nonlinearity) {
NoGradGuard guard;
Fan fan(tensor);
const auto gain = calculate_gain(nonlinearity, a);
double std = 0.0;
if (mode == torch::nn::init::FanMode::FanIn) {

// Support for `torch::nn::init::FanMode` is deprecated and will be removed in 1.5.
if (c10::get_if<FanMode>(&mode)) {
mode = _compute_fanmode_type(c10::get<FanMode>(mode));
}
if (c10::get_if<enumtype::kFanIn>(&mode)) {
std = gain / std::sqrt(fan.in);
} else {
std = gain / std::sqrt(fan.out);
Expand All @@ -53,12 +109,16 @@ double calculate_kaiming_std(
}
} // namespace

double calculate_gain(Nonlinearity nonlinearity, double param) {
if (nonlinearity == torch::nn::init::Nonlinearity::Tanh) {
double calculate_gain(NonlinearityType nonlinearity, double param) {
// Support for `torch::nn::init::Nonlinearity` is deprecated and will be removed in 1.5.
if (c10::get_if<Nonlinearity>(&nonlinearity)) {
nonlinearity = _compute_nonlinearity_type(c10::get<Nonlinearity>(nonlinearity));
}
if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
return 5.0 / 3.0; // NOLINT
} else if (nonlinearity == torch::nn::init::Nonlinearity::ReLU) {
} else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
return std::sqrt(2.0); // NOLINT
} else if (nonlinearity == torch::nn::init::Nonlinearity::LeakyReLU) {
} else if (c10::get_if<enumtype::kLeakyReLU>(&nonlinearity)) {
return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT
}

Expand Down Expand Up @@ -178,8 +238,8 @@ Tensor uniform_(Tensor tensor, double low, double high) {
Tensor kaiming_uniform_(
Tensor tensor,
double a,
FanMode mode,
Nonlinearity nonlinearity) {
FanModeType mode,
NonlinearityType nonlinearity) {
NoGradGuard guard;
auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
// Calculate uniform bounds from standard deviation
Expand All @@ -190,8 +250,8 @@ Tensor kaiming_uniform_(
Tensor kaiming_normal_(
Tensor tensor,
double a,
FanMode mode,
Nonlinearity nonlinearity) {
FanModeType mode,
NonlinearityType nonlinearity) {
NoGradGuard guard;

auto std = calculate_kaiming_std(tensor, a, mode, nonlinearity);
Expand Down

0 comments on commit eb4bb00

Please sign in to comment.