diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
index 63cdec01e57e2c..3ac82064cc8175 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc
@@ -557,6 +557,16 @@ bool OpenGlInfo::SupportsExplicitFp16() const {
   return supports_f16_alu && supports_f16_storage;
 }
 
+// True when major_version.minor_version is at least 3.1.
+bool OpenGlInfo::IsApiOpenGl31OrAbove() const {
+  return major_version > 3 || (major_version == 3 && minor_version >= 1);
+}
+
+// True when major_version.minor_version is at least 3.2.
+bool OpenGlInfo::IsApiOpenGl32OrAbove() const {
+  return major_version > 3 || (major_version == 3 && minor_version >= 2);
+}
+
 bool VulkanInfo::SupportsExplicitFp16() const {
   bool supports_f16_alu = false;
   bool supports_f16_storage = false;
@@ -691,6 +701,20 @@ bool GpuInfo::SupportsPointersInKernels() const {
   return IsApiOpenCl() || IsApiMetal();
 }
 
+// Zero clamp for image buffers is reported only for Metal and OpenCL.
+bool GpuInfo::SupportsZeroClampForImageBuffer() const {
+  return IsApiMetal() || IsApiOpenCl();
+}
+
+// Zero clamp for images is reported for Metal, OpenCL and Vulkan; for
+// OpenGL it additionally requires API version 3.2 or above.
+bool GpuInfo::SupportsZeroClampForImages() const {
+  if (IsApiMetal() || IsApiOpenCl() || IsApiVulkan()) {
+    return true;
+  }
+  return IsApiOpenGl() && opengl_info.IsApiOpenGl32OrAbove();
+}
+
 bool GpuInfo::IsWaveSizeEqualTo32() const {
   return supported_subgroup_sizes.size() == 1 &&
          supported_subgroup_sizes[0] == 32;
@@ -949,8 +973,7 @@ bool GpuInfo::IsApiOpenGl31OrAbove() const {
   if (!IsApiOpenGl()) {
     return false;
   }
-  return (opengl_info.major_version == 3 && opengl_info.minor_version >= 1) ||
-         opengl_info.major_version > 3;
+  return opengl_info.IsApiOpenGl31OrAbove();
 }
 
 bool GpuInfo::IsApiVulkan() const { return gpu_api == GpuApi::kVulkan; }
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h
index 69dc42d21a95c2..8a5ddc2f2b662b 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_info.h
+++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h
@@ -278,6 +278,9 @@ struct OpenGlInfo {
   int max_compute_work_group_size_z;
 
   bool SupportsExplicitFp16() const;
+
+  bool IsApiOpenGl31OrAbove() const;
+  bool IsApiOpenGl32OrAbove() const;
 };
 
 struct VulkanInfo {
@@ -437,6 +440,9 @@ struct GpuInfo {
   bool SupportsFloatImage2D(DataType data_type, int channels) const;
   bool SupportsExtension(const std::string& extension) const;
 
+  bool SupportsZeroClampForImageBuffer() const;
+  bool SupportsZeroClampForImages() const;
+
   int GetComputeUnitsCount() const;
 
   int GetMaxImageArguments() const;
diff --git a/tensorflow/lite/delegates/gpu/common/gpu_model.cc b/tensorflow/lite/delegates/gpu/common/gpu_model.cc
index cce79214ab3aba..b2fc8a44d2c05f 100644
--- a/tensorflow/lite/delegates/gpu/common/gpu_model.cc
+++ b/tensorflow/lite/delegates/gpu/common/gpu_model.cc
@@ -197,6 +197,14 @@ absl::Status ReserveGraphTensors(const CreateGpuModelInfo& create_info,
                                  const GpuInfo& gpu_info,
                                  const GraphFloat32& graph,
                                  TensorReserver* tensor_reserver) {
+  // The three image storage kinds all receive the same zero-clamp answer.
+  const bool zero_clamp_for_images = gpu_info.SupportsZeroClampForImages();
+  ZeroClampSupport zero_clamp_support;
+  zero_clamp_support.image_buffer = gpu_info.SupportsZeroClampForImageBuffer();
+  zero_clamp_support.image2d = zero_clamp_for_images;
+  zero_clamp_support.image2d_array = zero_clamp_for_images;
+  zero_clamp_support.image3d = zero_clamp_for_images;
+
   ValueId max_id = 0;
   auto tensors = graph.values();
   for (auto& t : tensors) {
@@ -261,6 +269,7 @@ absl::Status ReserveGraphTensors(const CreateGpuModelInfo& create_info,
       }
     }
     tensor_desc.SetBHWCShape(shape);
+    tensor_desc.SetZeroClampSupport(zero_clamp_support);
     tensor_reserver->Add(t->id, tensor_desc);
     max_id = std::max(max_id, t->id);
   }