Skip to content

Commit

Permalink
Add scalable vector (#542)
Browse files Browse the repository at this point in the history
* feat: add scalable vector support

* test: add some scalable vector

* feat: add separate scalable vector type and value

Introduce `VectorBaseValue` trait for both fixed and scalable vector values.

Add type, value and builder tests for scalable vector.

Revise min. LLVM version from 11 to 12 due to llvm_sys.

* fix: various bugs

fix: remove always false functions

style: cargo fmt

fix: rename doc comments

fix: remove leftover test

fix: add conditional llvm versions to type kind variant

fix: change to doctest to ignore

fix: remove llvm 11 from scalable vector type test

---------

Co-authored-by: Dan Kolsoi <[email protected]>
  • Loading branch information
my4ng and TheDan64 authored Oct 28, 2024
1 parent 7b41029 commit 8bd9f08
Show file tree
Hide file tree
Showing 16 changed files with 1,388 additions and 43 deletions.
26 changes: 13 additions & 13 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use crate::values::CallableValue;
use crate::values::{
AggregateValue, AggregateValueEnum, AsValueRef, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue,
FloatMathValue, FunctionValue, GlobalValue, InstructionOpcode, InstructionValue, IntMathValue, IntValue, PhiValue,
PointerMathValue, PointerValue, StructValue, VectorValue,
PointerMathValue, PointerValue, StructValue, VectorBaseValue,
};

use crate::{AtomicOrdering, AtomicRMWBinOp, FloatPredicate, IntPredicate};
Expand Down Expand Up @@ -3064,9 +3064,9 @@ impl<'ctx> Builder<'ctx> {
///
/// builder.build_return(Some(&extracted)).unwrap();
/// ```
pub fn build_extract_element(
pub fn build_extract_element<V: VectorBaseValue<'ctx>>(
&self,
vector: VectorValue<'ctx>,
vector: V,
index: IntValue<'ctx>,
name: &str,
) -> Result<BasicValueEnum<'ctx>, BuilderError> {
Expand Down Expand Up @@ -3112,13 +3112,13 @@ impl<'ctx> Builder<'ctx> {
/// builder.build_insert_element(vector_param, i32_seven, i32_zero, "insert").unwrap();
/// builder.build_return(None).unwrap();
/// ```
pub fn build_insert_element<V: BasicValue<'ctx>>(
pub fn build_insert_element<V: BasicValue<'ctx>, W: VectorBaseValue<'ctx>>(
&self,
vector: VectorValue<'ctx>,
vector: W,
element: V,
index: IntValue<'ctx>,
name: &str,
) -> Result<VectorValue<'ctx>, BuilderError> {
) -> Result<W, BuilderError> {
if self.positioned.get() != PositionState::Set {
return Err(BuilderError::UnsetPosition);
}
Expand All @@ -3134,7 +3134,7 @@ impl<'ctx> Builder<'ctx> {
)
};

unsafe { Ok(VectorValue::new(value)) }
unsafe { Ok(W::new(value)) }
}

pub fn build_unreachable(&self) -> Result<InstructionValue<'ctx>, BuilderError> {
Expand Down Expand Up @@ -3329,13 +3329,13 @@ impl<'ctx> Builder<'ctx> {
}

// REVIEW: Do we need to constrain types here? subtypes?
pub fn build_shuffle_vector(
pub fn build_shuffle_vector<V: VectorBaseValue<'ctx>>(
&self,
left: VectorValue<'ctx>,
right: VectorValue<'ctx>,
mask: VectorValue<'ctx>,
left: V,
right: V,
mask: V,
name: &str,
) -> Result<VectorValue<'ctx>, BuilderError> {
) -> Result<V, BuilderError> {
if self.positioned.get() != PositionState::Set {
return Err(BuilderError::UnsetPosition);
}
Expand All @@ -3350,7 +3350,7 @@ impl<'ctx> Builder<'ctx> {
)
};

unsafe { Ok(VectorValue::new(value)) }
unsafe { Ok(V::new(value)) }
}

// REVIEW: Is return type correct?
Expand Down
56 changes: 53 additions & 3 deletions src/types/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use llvm_sys::LLVMTypeKind;
use crate::support::LLVMString;
use crate::types::traits::AsTypeRef;
use crate::types::MetadataType;
use crate::types::{ArrayType, FloatType, FunctionType, IntType, PointerType, StructType, VectorType, VoidType};
use crate::types::{
ArrayType, FloatType, FunctionType, IntType, PointerType, ScalableVectorType, StructType, VectorType, VoidType,
};
use crate::values::{BasicValue, BasicValueEnum, IntValue};

use std::convert::TryFrom;
Expand Down Expand Up @@ -70,6 +72,8 @@ enum_type_set! {
StructType,
/// A contiguous homogeneous "SIMD" container type.
VectorType,
/// A contiguous homogenous scalable "SIMD" container type.
ScalableVectorType,
/// A valueless type.
VoidType,
}
Expand All @@ -89,6 +93,8 @@ enum_type_set! {
StructType,
/// A contiguous homogeneous "SIMD" container type.
VectorType,
/// A contiguous homogenous scalable "SIMD" container type.
ScalableVectorType,
}
}
enum_type_set! {
Expand All @@ -99,6 +105,7 @@ enum_type_set! {
PointerType,
StructType,
VectorType,
ScalableVectorType,
MetadataType,
}
}
Expand Down Expand Up @@ -152,6 +159,14 @@ impl<'ctx> BasicMetadataTypeEnum<'ctx> {
}
}

pub fn into_scalable_vector_type(self) -> ScalableVectorType<'ctx> {
if let BasicMetadataTypeEnum::ScalableVectorType(t) = self {
t
} else {
panic!("Found {:?} but expected another variant", self);
}
}

pub fn into_metadata_type(self) -> MetadataType<'ctx> {
if let BasicMetadataTypeEnum::MetadataType(t) = self {
t
Expand Down Expand Up @@ -188,6 +203,10 @@ impl<'ctx> BasicMetadataTypeEnum<'ctx> {
matches!(self, BasicMetadataTypeEnum::VectorType(_))
}

pub fn is_scalable_vector_type(self) -> bool {
matches!(self, BasicMetadataTypeEnum::ScalableVectorType(_))
}

/// Print the definition of a `BasicMetadataTypeEnum` to `LLVMString`.
pub fn print_to_string(self) -> LLVMString {
match self {
Expand All @@ -197,6 +216,7 @@ impl<'ctx> BasicMetadataTypeEnum<'ctx> {
BasicMetadataTypeEnum::PointerType(t) => t.print_to_string(),
BasicMetadataTypeEnum::StructType(t) => t.print_to_string(),
BasicMetadataTypeEnum::VectorType(t) => t.print_to_string(),
BasicMetadataTypeEnum::ScalableVectorType(t) => t.print_to_string(),
BasicMetadataTypeEnum::MetadataType(t) => t.print_to_string(),
}
}
Expand Down Expand Up @@ -244,7 +264,7 @@ impl<'ctx> AnyTypeEnum<'ctx> {
feature = "llvm17-0",
feature = "llvm18-0"
))]
LLVMTypeKind::LLVMScalableVectorTypeKind => AnyTypeEnum::VectorType(VectorType::new(type_)),
LLVMTypeKind::LLVMScalableVectorTypeKind => AnyTypeEnum::ScalableVectorType(ScalableVectorType::new(type_)),
// FIXME: should inkwell support metadata as AnyType?
LLVMTypeKind::LLVMMetadataTypeKind => panic!("Metadata type is not supported as AnyType."),
LLVMTypeKind::LLVMX86_MMXTypeKind => panic!("FIXME: Unsupported type: MMX"),
Expand Down Expand Up @@ -325,6 +345,14 @@ impl<'ctx> AnyTypeEnum<'ctx> {
}
}

pub fn into_scalable_vector_type(self) -> ScalableVectorType<'ctx> {
if let AnyTypeEnum::ScalableVectorType(t) = self {
t
} else {
panic!("Found {:?} but expected the ScalableVectorType variant", self);
}
}

pub fn into_void_type(self) -> VoidType<'ctx> {
if let AnyTypeEnum::VoidType(t) = self {
t
Expand Down Expand Up @@ -373,6 +401,7 @@ impl<'ctx> AnyTypeEnum<'ctx> {
AnyTypeEnum::PointerType(t) => Some(t.size_of()),
AnyTypeEnum::StructType(t) => t.size_of(),
AnyTypeEnum::VectorType(t) => t.size_of(),
AnyTypeEnum::ScalableVectorType(t) => t.size_of(),
AnyTypeEnum::VoidType(_) => None,
AnyTypeEnum::FunctionType(_) => None,
}
Expand All @@ -387,6 +416,7 @@ impl<'ctx> AnyTypeEnum<'ctx> {
AnyTypeEnum::PointerType(t) => t.print_to_string(),
AnyTypeEnum::StructType(t) => t.print_to_string(),
AnyTypeEnum::VectorType(t) => t.print_to_string(),
AnyTypeEnum::ScalableVectorType(t) => t.print_to_string(),
AnyTypeEnum::VoidType(t) => t.print_to_string(),
AnyTypeEnum::FunctionType(t) => t.print_to_string(),
}
Expand Down Expand Up @@ -432,7 +462,9 @@ impl<'ctx> BasicTypeEnum<'ctx> {
feature = "llvm17-0",
feature = "llvm18-0"
))]
LLVMTypeKind::LLVMScalableVectorTypeKind => BasicTypeEnum::VectorType(VectorType::new(type_)),
LLVMTypeKind::LLVMScalableVectorTypeKind => {
BasicTypeEnum::ScalableVectorType(ScalableVectorType::new(type_))
},
LLVMTypeKind::LLVMMetadataTypeKind => panic!("Unsupported basic type: Metadata"),
// see https://llvm.org/docs/LangRef.html#x86-mmx-type
LLVMTypeKind::LLVMX86_MMXTypeKind => panic!("Unsupported basic type: MMX"),
Expand Down Expand Up @@ -504,6 +536,14 @@ impl<'ctx> BasicTypeEnum<'ctx> {
}
}

pub fn into_scalable_vector_type(self) -> ScalableVectorType<'ctx> {
if let BasicTypeEnum::ScalableVectorType(t) = self {
t
} else {
panic!("Found {:?} but expected the ScalableVectorType variant", self);
}
}

pub fn is_array_type(self) -> bool {
matches!(self, BasicTypeEnum::ArrayType(_))
}
Expand All @@ -528,6 +568,10 @@ impl<'ctx> BasicTypeEnum<'ctx> {
matches!(self, BasicTypeEnum::VectorType(_))
}

pub fn is_scalable_vector_type(self) -> bool {
matches!(self, BasicTypeEnum::ScalableVectorType(_))
}

/// Creates a constant `BasicValueZero`.
///
/// # Example
Expand All @@ -547,6 +591,7 @@ impl<'ctx> BasicTypeEnum<'ctx> {
BasicTypeEnum::PointerType(ty) => ty.const_zero().as_basic_value_enum(),
BasicTypeEnum::StructType(ty) => ty.const_zero().as_basic_value_enum(),
BasicTypeEnum::VectorType(ty) => ty.const_zero().as_basic_value_enum(),
BasicTypeEnum::ScalableVectorType(ty) => ty.const_zero().as_basic_value_enum(),
}
}

Expand All @@ -559,6 +604,7 @@ impl<'ctx> BasicTypeEnum<'ctx> {
BasicTypeEnum::PointerType(t) => t.print_to_string(),
BasicTypeEnum::StructType(t) => t.print_to_string(),
BasicTypeEnum::VectorType(t) => t.print_to_string(),
BasicTypeEnum::ScalableVectorType(t) => t.print_to_string(),
}
}
}
Expand All @@ -575,6 +621,7 @@ impl<'ctx> TryFrom<AnyTypeEnum<'ctx>> for BasicTypeEnum<'ctx> {
PointerType(pt) => pt.into(),
StructType(st) => st.into(),
VectorType(vt) => vt.into(),
ScalableVectorType(vt) => vt.into(),
VoidType(_) | FunctionType(_) => return Err(()),
})
}
Expand All @@ -592,6 +639,7 @@ impl<'ctx> TryFrom<AnyTypeEnum<'ctx>> for BasicMetadataTypeEnum<'ctx> {
PointerType(pt) => pt.into(),
StructType(st) => st.into(),
VectorType(vt) => vt.into(),
ScalableVectorType(vt) => vt.into(),
VoidType(_) | FunctionType(_) => return Err(()),
})
}
Expand All @@ -609,6 +657,7 @@ impl<'ctx> TryFrom<BasicMetadataTypeEnum<'ctx>> for BasicTypeEnum<'ctx> {
PointerType(pt) => pt.into(),
StructType(st) => st.into(),
VectorType(vt) => vt.into(),
ScalableVectorType(vt) => vt.into(),
MetadataType(_) => return Err(()),
})
}
Expand All @@ -624,6 +673,7 @@ impl<'ctx> From<BasicTypeEnum<'ctx>> for BasicMetadataTypeEnum<'ctx> {
PointerType(pt) => pt.into(),
StructType(st) => st.into(),
VectorType(vt) => vt.into(),
ScalableVectorType(vt) => vt.into(),
}
}
}
Expand Down
33 changes: 26 additions & 7 deletions src/types/float_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::context::ContextRef;
use crate::support::LLVMString;
use crate::types::enums::BasicMetadataTypeEnum;
use crate::types::traits::AsTypeRef;
use crate::types::{ArrayType, FunctionType, PointerType, Type, VectorType};
use crate::types::{ArrayType, FunctionType, PointerType, ScalableVectorType, Type, VectorType};
use crate::values::{ArrayValue, FloatValue, GenericValue, IntValue};
use crate::AddressSpace;

Expand Down Expand Up @@ -64,7 +64,7 @@ impl<'ctx> FloatType<'ctx> {
self.float_type.array_type(size)
}

/// Creates a `VectorType` with this `FloatType` for its element type.
/// Creates a `ScalableVectorType` with this `FloatType` for its element type.
///
/// # Example
///
Expand All @@ -73,15 +73,34 @@ impl<'ctx> FloatType<'ctx> {
///
/// let context = Context::create();
/// let f32_type = context.f32_type();
/// let f32_vector_type = f32_type.vec_type(3);
/// let f32_scalable_vector_type = f32_type.vec_type(3);
///
/// assert_eq!(f32_vector_type.get_size(), 3);
/// assert_eq!(f32_vector_type.get_element_type().into_float_type(), f32_type);
/// assert_eq!(f32_scalable_vector_type.get_size(), 3);
/// assert_eq!(f32_scalable_vector_type.get_element_type().into_float_type(), f32_type);
/// ```
pub fn vec_type(self, size: u32) -> VectorType<'ctx> {
self.float_type.vec_type(size)
}

/// Creates a scalable `VectorType` with this `FloatType` for its element type.
///
/// # Example
///
/// ```no_run
/// use inkwell::context::Context;
///
/// let context = Context::create();
/// let f32_type = context.f32_type();
/// let f32_vector_type = f32_type.scalable_vec_type(3);
///
/// assert_eq!(f32_vector_type.get_size(), 3);
/// assert_eq!(f32_vector_type.get_element_type().into_float_type(), f32_type);
/// ```
#[llvm_versions(12..)]
pub fn scalable_vec_type(self, size: u32) -> ScalableVectorType<'ctx> {
self.float_type.scalable_vec_type(size)
}

/// Creates a `FloatValue` representing a constant value of this `FloatType`.
/// It will be automatically assigned this `FloatType`'s `Context`.
///
Expand Down Expand Up @@ -128,9 +147,9 @@ impl<'ctx> FloatType<'ctx> {
/// assert_eq!(f64_val.print_to_string().to_string(), "double 0x7FF0000000000000");
/// ```
pub unsafe fn const_float_from_string(self, slice: &str) -> FloatValue<'ctx> {
assert!(!slice.is_empty());
assert!(!slice.is_empty());

unsafe {
unsafe {
FloatValue::new(LLVMConstRealOfStringAndSize(
self.as_type_ref(),
slice.as_ptr() as *const ::libc::c_char,
Expand Down
21 changes: 20 additions & 1 deletion src/types/int_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use llvm_sys::prelude::LLVMTypeRef;
use crate::context::ContextRef;
use crate::support::LLVMString;
use crate::types::traits::AsTypeRef;
use crate::types::{ArrayType, FunctionType, PointerType, Type, VectorType};
use crate::types::{ArrayType, FunctionType, PointerType, ScalableVectorType, Type, VectorType};
use crate::values::{ArrayValue, GenericValue, IntValue};
use crate::AddressSpace;

Expand Down Expand Up @@ -244,6 +244,25 @@ impl<'ctx> IntType<'ctx> {
self.int_type.vec_type(size)
}

/// Creates a `ScalableVectorType` with this `IntType` for its element type.
///
/// # Example
///
/// ```no_run
/// use inkwell::context::Context;
///
/// let context = Context::create();
/// let i8_type = context.i8_type();
/// let i8_scalable_vector_type = i8_type.scalable_vec_type(3);
///
/// assert_eq!(i8_scalable_vector_type.get_size(), 3);
/// assert_eq!(i8_scalable_vector_type.get_element_type().into_int_type(), i8_type);
/// ```
#[llvm_versions(12..)]
pub fn scalable_vec_type(self, size: u32) -> ScalableVectorType<'ctx> {
self.int_type.scalable_vec_type(size)
}

/// Gets a reference to the `Context` this `IntType` was created in.
///
/// # Example
Expand Down
Loading

0 comments on commit 8bd9f08

Please sign in to comment.