Skip to content

Commit c305956

Browse files
committed
added interface to register loss from outside library
1 parent bd1f727 commit c305956

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

include/tiny-cuda-nn/loss.h

+3
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,7 @@ std::unique_ptr<Loss<T>> default_loss(const std::string& name) {
7070

7171
std::vector<std::string> builtin_losses();
7272

73+
template <typename T>
74+
void register_loss(const std::string& name, const std::function<Loss<T>*(const json&)>& factory);
75+
7376
}

src/loss.cu

+3
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ void register_loss(const std::string& name, const std::function<Loss<T>*(const j
7878
register_loss(loss_factories<T>(), name, factory);
7979
}
8080

81+
template void register_loss<float>(const std::string& name, const std::function<Loss<float>*(const json&)>& factory);
82+
template void register_loss<__half>(const std::string& name, const std::function<Loss<__half>*(const json&)>& factory);
83+
8184
template <typename T>
8285
Loss<T>* create_loss(const json& loss) {
8386
std::string name = loss.value("otype", "RelativeL2");

0 commit comments

Comments
 (0)