Skip to content

Commit

Permalink
Moved inline functions into separate translation units
Browse files Browse the repository at this point in the history
Instead of using inline keyword to allow multiple definitions of the same function
in different translation units, introduced elementwise_functions_type_utils.cpp
that defines these functions and a header file to use in other translation units.

This should reduce the binary size of the produced object files and simplify the
linker's job, reducing link time.
  • Loading branch information
oleksandr-pavlyk committed Oct 24, 2023
1 parent 22b04e4 commit fa4924a
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 52 deletions.
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ endif()

set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <pybind11/stl.h>
#include <vector>

#include "elementwise_functions_type_utils.hpp"
#include "simplify_iteration_space.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/offset_utils.hpp"
Expand All @@ -46,56 +47,7 @@ namespace tensor
namespace py_internal
{

namespace
{
// File-local helpers (anonymous namespace): visible only to this
// translation unit. `inline` permits identical definitions in other TUs
// that include this header-style code.

/*! @brief Map a type-dispatch type number onto the matching NumPy dtype.
 *  Throws py::value_error for a type number outside the recognized set. */
inline py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t)
{
switch (dst_typenum_t) {
case td_ns::typenum_t::BOOL:
return py::dtype("?");
case td_ns::typenum_t::INT8:
return py::dtype("i1");
case td_ns::typenum_t::UINT8:
return py::dtype("u1");
case td_ns::typenum_t::INT16:
return py::dtype("i2");
case td_ns::typenum_t::UINT16:
return py::dtype("u2");
case td_ns::typenum_t::INT32:
return py::dtype("i4");
case td_ns::typenum_t::UINT32:
return py::dtype("u4");
case td_ns::typenum_t::INT64:
return py::dtype("i8");
case td_ns::typenum_t::UINT64:
return py::dtype("u8");
case td_ns::typenum_t::HALF:
return py::dtype("f2");
case td_ns::typenum_t::FLOAT:
return py::dtype("f4");
case td_ns::typenum_t::DOUBLE:
return py::dtype("f8");
case td_ns::typenum_t::CFLOAT:
return py::dtype("c8");
case td_ns::typenum_t::CDOUBLE:
return py::dtype("c16");
default:
throw py::value_error("Unrecognized dst_typeid");
}
}

/*! @brief Look up the result typeid for `arg_typeid` in the per-function
 *  mapping table `fn_output_id`; throws py::value_error when `arg_typeid`
 *  lies outside [0, td_ns::num_types). */
inline int _result_typeid(int arg_typeid, const int *fn_output_id)
{
if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) {
throw py::value_error("Input typeid " + std::to_string(arg_typeid) +
" is outside of expected bounds.");
}

return fn_output_id[arg_typeid];
}

} // end of anonymous namespace

/*! @brief Template implementing Python API for unary elementwise functions */
template <typename output_typesT,
typename contig_dispatchT,
typename strided_dispatchT>
Expand Down Expand Up @@ -297,6 +249,8 @@ py_unary_ufunc(const dpctl::tensor::usm_ndarray &src,
strided_fn_ev);
}

/*! @brief Template implementing Python API for querying of type support by
* unary elementwise functions */
template <typename output_typesT>
py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
const output_typesT &output_types)
Expand All @@ -312,15 +266,17 @@ py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
throw py::value_error(e.what());
}

using dpctl::tensor::py_internal::type_utils::_result_typeid;
int dst_typeid = _result_typeid(src_typeid, output_types);

if (dst_typeid < 0) {
auto res = py::none();
return py::cast<py::object>(res);
}
else {
auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum;

auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
auto dt = _dtype_from_typenum(dst_typenum_t);

return py::cast<py::object>(dt);
Expand All @@ -338,6 +294,8 @@ bool isEqual(Container const &c, std::initializer_list<T> const &l)
}
} // namespace

/*! @brief Template implementing Python API for binary elementwise
* functions */
template <typename output_typesT,
typename contig_dispatchT,
typename strided_dispatchT,
Expand Down Expand Up @@ -605,6 +563,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
strided_fn_ev);
}

/*! @brief Type querying for binary elementwise functions */
template <typename output_typesT>
py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
const py::dtype &input2_dtype,
Expand Down Expand Up @@ -636,8 +595,9 @@ py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype,
return py::cast<py::object>(res);
}
else {
auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum;

auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
auto dt = _dtype_from_typenum(dst_typenum_t);

return py::cast<py::object>(dt);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "dpctl4pybind11.hpp"
#include <CL/sycl.hpp>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include "elementwise_functions_type_utils.hpp"
#include "utils/type_dispatch.hpp"

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

namespace dpctl
{
namespace tensor
{
namespace py_internal
{
namespace type_utils
{

/*! @brief Map a type-dispatch type number onto the corresponding NumPy dtype.
 *
 *  @param dst_typenum_t  Type number from the type-dispatch enumeration.
 *  @return py::dtype constructed from the matching NumPy typestring.
 *  @throws py::value_error if the type number is not one of the recognized
 *          values.
 */
py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t)
{
    // Select the NumPy typestring first, then construct the dtype once on
    // the single exit path.
    const char *typestr = nullptr;
    switch (dst_typenum_t) {
    case td_ns::typenum_t::BOOL:
        typestr = "?";
        break;
    case td_ns::typenum_t::INT8:
        typestr = "i1";
        break;
    case td_ns::typenum_t::UINT8:
        typestr = "u1";
        break;
    case td_ns::typenum_t::INT16:
        typestr = "i2";
        break;
    case td_ns::typenum_t::UINT16:
        typestr = "u2";
        break;
    case td_ns::typenum_t::INT32:
        typestr = "i4";
        break;
    case td_ns::typenum_t::UINT32:
        typestr = "u4";
        break;
    case td_ns::typenum_t::INT64:
        typestr = "i8";
        break;
    case td_ns::typenum_t::UINT64:
        typestr = "u8";
        break;
    case td_ns::typenum_t::HALF:
        typestr = "f2";
        break;
    case td_ns::typenum_t::FLOAT:
        typestr = "f4";
        break;
    case td_ns::typenum_t::DOUBLE:
        typestr = "f8";
        break;
    case td_ns::typenum_t::CFLOAT:
        typestr = "c8";
        break;
    case td_ns::typenum_t::CDOUBLE:
        typestr = "c16";
        break;
    default:
        throw py::value_error("Unrecognized dst_typeid");
    }
    return py::dtype(typestr);
}

/*! @brief Look up the result typeid for a given argument typeid in a
 *  per-function output-type mapping table.
 *
 *  @param arg_typeid    Typeid of the input; must lie in [0, num_types).
 *  @param fn_output_id  Table mapping input typeid to output typeid.
 *  @return fn_output_id[arg_typeid].
 *  @throws py::value_error when arg_typeid is outside the valid range.
 */
int _result_typeid(int arg_typeid, const int *fn_output_id)
{
    const bool in_bounds =
        (0 <= arg_typeid) && (arg_typeid < td_ns::num_types);
    if (in_bounds) {
        return fn_output_id[arg_typeid];
    }
    throw py::value_error("Input typeid " + std::to_string(arg_typeid) +
                          " is outside of expected bounds.");
}

} // namespace type_utils
} // namespace py_internal
} // namespace tensor
} // namespace dpctl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Declarations for type-query helpers shared by the elementwise-function
// translation units; definitions live in
// elementwise_functions_type_utils.cpp.
// Fix: the header previously contained `#pragma once` twice; one directive
// suffices.
#pragma once

#include "dpctl4pybind11.hpp"
#include <CL/sycl.hpp>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include "utils/type_dispatch.hpp"

// NOTE(review): namespace aliases at global scope in a header leak into
// every includer; consider moving them inside the dpctl namespaces or
// spelling the qualified names out — confirm against project convention.
namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

namespace dpctl
{
namespace tensor
{
namespace py_internal
{
namespace type_utils
{

/*! @brief Produce dtype from a type number */
extern py::dtype _dtype_from_typenum(td_ns::typenum_t);

/*! @brief Lookup typeid of the result from typeid of
 *  argument and the mapping table */
extern int _result_typeid(int, const int *);

} // namespace type_utils
} // namespace py_internal
} // namespace tensor
} // namespace dpctl

0 comments on commit fa4924a

Please sign in to comment.