Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing iterators #461

Merged
merged 1 commit into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,11 @@ impl<'ctx> BasicBlock<'ctx> {
unsafe { Some(InstructionValue::new(value)) }
}

/// Get an instruction iterator
pub fn get_instructions(self) -> InstructionIter<'ctx> {
InstructionIter(self.get_first_instruction())
}

/// Removes this `BasicBlock` from its parent `FunctionValue`.
/// It returns `Err(())` when it has no parent to remove from.
///
Expand Down Expand Up @@ -597,3 +602,20 @@ impl fmt::Debug for BasicBlock<'_> {
.finish()
}
}

/// Iterate over all `InstructionValue`s in a basic block.
#[derive(Debug)]
pub struct InstructionIter<'ctx>(Option<InstructionValue<'ctx>>);

impl<'ctx> Iterator for InstructionIter<'ctx> {
type Item = InstructionValue<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
if let Some(instr) = self.0 {
self.0 = instr.get_next_instruction();
Some(instr)
} else {
None
}
}
}
78 changes: 14 additions & 64 deletions src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1573,96 +1573,46 @@ pub enum FlagBehavior {

/// Iterate over all `FunctionValue`s in an llvm module
#[derive(Debug)]
pub struct FunctionIterator<'ctx>(FunctionIteratorInner<'ctx>);

/// Inner type so the variants are not publicly visible
#[derive(Debug)]
enum FunctionIteratorInner<'ctx> {
Empty,
Start(FunctionValue<'ctx>),
Previous(FunctionValue<'ctx>),
}
pub struct FunctionIterator<'ctx>(Option<FunctionValue<'ctx>>);

impl<'ctx> FunctionIterator<'ctx> {
fn from_module(module: &Module<'ctx>) -> Self {
use FunctionIteratorInner::*;

match module.get_first_function() {
None => Self(Empty),
Some(first) => Self(Start(first)),
}
Self(module.get_first_function())
}
}

impl<'ctx> Iterator for FunctionIterator<'ctx> {
type Item = FunctionValue<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
use FunctionIteratorInner::*;

match self.0 {
Empty => None,
Start(first) => {
self.0 = Previous(first);

Some(first)
},
Previous(prev) => match prev.get_next_function() {
Some(current) => {
self.0 = Previous(current);

Some(current)
},
None => None,
},
if let Some(func) = self.0 {
self.0 = func.get_next_function();
Some(func)
} else {
None
}
}
}

/// Iterate over all `GlobalValue`s in an llvm module
#[derive(Debug)]
pub struct GlobalIterator<'ctx>(GlobalIteratorInner<'ctx>);

/// Inner type so the variants are not publicly visible
#[derive(Debug)]
enum GlobalIteratorInner<'ctx> {
Empty,
Start(GlobalValue<'ctx>),
Previous(GlobalValue<'ctx>),
}
pub struct GlobalIterator<'ctx>(Option<GlobalValue<'ctx>>);

impl<'ctx> GlobalIterator<'ctx> {
fn from_module(module: &Module<'ctx>) -> Self {
use GlobalIteratorInner::*;

match module.get_first_global() {
None => Self(Empty),
Some(first) => Self(Start(first)),
}
Self(module.get_first_global())
}
}

impl<'ctx> Iterator for GlobalIterator<'ctx> {
type Item = GlobalValue<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
use GlobalIteratorInner::*;

match self.0 {
Empty => None,
Start(first) => {
self.0 = Previous(first);

Some(first)
},
Previous(prev) => match prev.get_next_global() {
Some(current) => {
self.0 = Previous(current);

Some(current)
},
None => None,
},
if let Some(global) = self.0 {
self.0 = global.get_next_global();
Some(global)
} else {
None
}
}
}
21 changes: 21 additions & 0 deletions src/values/fn_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ impl<'ctx> FunctionValue<'ctx> {
unsafe { LLVMCountBasicBlocks(self.as_value_ref()) }
}

pub fn get_basic_block_iter(self) -> BasicBlockIter<'ctx> {
BasicBlockIter(self.get_first_basic_block())
}

pub fn get_basic_blocks(self) -> Vec<BasicBlock<'ctx>> {
let count = self.count_basic_blocks();
let mut raw_vec: Vec<LLVMBasicBlockRef> = Vec::with_capacity(count as usize);
Expand Down Expand Up @@ -552,6 +556,23 @@ impl fmt::Debug for FunctionValue<'_> {
}
}

/// Iterate over all `BasicBlock`s in a function.
#[derive(Debug)]
pub struct BasicBlockIter<'ctx>(Option<BasicBlock<'ctx>>);

impl<'ctx> Iterator for BasicBlockIter<'ctx> {
type Item = BasicBlock<'ctx>;

fn next(&mut self) -> Option<Self::Item> {
if let Some(bb) = self.0 {
self.0 = bb.get_next_basic_block();
Some(bb)
} else {
None
}
}
}

#[derive(Debug)]
pub struct ParamValueIter<'ctx> {
param_iter_value: LLVMValueRef,
Expand Down
50 changes: 26 additions & 24 deletions tests/all/test_basic_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,33 @@ fn test_basic_block_ordering() {
let basic_block2 = context.insert_basic_block_after(basic_block, "block2");
let basic_block3 = context.prepend_basic_block(basic_block4, "block3");

let basic_blocks = function.get_basic_blocks();

assert_eq!(basic_blocks.len(), 4);
assert_eq!(basic_blocks[0], basic_block);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block3);
assert_eq!(basic_blocks[3], basic_block4);
for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] {
assert_eq!(basic_blocks.len(), 4);
assert_eq!(basic_blocks[0], basic_block);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block3);
assert_eq!(basic_blocks[3], basic_block4);
}

assert!(basic_block3.move_before(basic_block2).is_ok());
assert!(basic_block.move_after(basic_block4).is_ok());

let basic_block5 = context.prepend_basic_block(basic_block, "block5");
let basic_blocks = function.get_basic_blocks();

assert_eq!(basic_blocks.len(), 5);
assert_eq!(basic_blocks[0], basic_block3);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block4);
assert_eq!(basic_blocks[3], basic_block5);
assert_eq!(basic_blocks[4], basic_block);

assert_ne!(basic_blocks[0], basic_block);
assert_ne!(basic_blocks[1], basic_block3);
assert_ne!(basic_blocks[2], basic_block2);
assert_ne!(basic_blocks[3], basic_block4);
assert_ne!(basic_blocks[4], basic_block5);
for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] {
assert_eq!(basic_blocks.len(), 5);
assert_eq!(basic_blocks[0], basic_block3);
assert_eq!(basic_blocks[1], basic_block2);
assert_eq!(basic_blocks[2], basic_block4);
assert_eq!(basic_blocks[3], basic_block5);
assert_eq!(basic_blocks[4], basic_block);

assert_ne!(basic_blocks[0], basic_block);
assert_ne!(basic_blocks[1], basic_block3);
assert_ne!(basic_blocks[2], basic_block2);
assert_ne!(basic_blocks[3], basic_block4);
assert_ne!(basic_blocks[4], basic_block5);
}

context.append_basic_block(function, "block6");

Expand Down Expand Up @@ -89,6 +90,7 @@ fn test_get_basic_blocks() {

assert!(function.get_last_basic_block().is_none());
assert_eq!(function.get_basic_blocks().len(), 0);
assert_eq!(function.get_basic_block_iter().count(), 0);

let basic_block = context.append_basic_block(function, "entry");

Expand All @@ -98,10 +100,10 @@ fn test_get_basic_blocks() {

assert_eq!(last_basic_block, basic_block);

let basic_blocks = function.get_basic_blocks();

assert_eq!(basic_blocks.len(), 1);
assert_eq!(basic_blocks[0], basic_block);
for basic_blocks in [function.get_basic_blocks(), function.get_basic_block_iter().collect()] {
assert_eq!(basic_blocks.len(), 1);
assert_eq!(basic_blocks[0], basic_block);
}
}

#[test]
Expand Down