Skip to content

Commit

Permalink
Merge pull request #47 from CodeSandwich/refactor
Browse files Browse the repository at this point in the history
Refactor
  • Loading branch information
CodeSandwich authored Jun 13, 2019
2 parents 25b74a5 + 3f26cad commit 00e4fbc
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 120 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,4 @@ pub mod macros {
pub use mocktopus_macros::*;
}


mod mock_store;
141 changes: 141 additions & 0 deletions src/mock_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use crate::mocking::MockResult;
use std::any::TypeId;
use std::cell::RefCell;
use std::collections::HashMap;
use std::mem::transmute;
use std::rc::Rc;

pub struct MockStore {
layers: RefCell<Vec<MockLayer>>,
}

impl MockStore {
pub fn clear(&self) {
for layer in self.layers.borrow_mut().iter_mut() {
layer.clear()
}
}

pub fn clear_id(&self, id: TypeId) {
for layer in self.layers.borrow_mut().iter_mut() {
layer.clear_id(id)
}
}

/// Layer will be in use as long as MockLayerGuard is alive
/// MockLayerGuards must always be dropped and always in reverse order of their creation
pub unsafe fn add_layer(&self, layer: MockLayer) {
self.layers.borrow_mut().push(layer)
}

pub unsafe fn remove_layer(&self) {
self.layers.borrow_mut().pop();
}

pub unsafe fn add_to_thread_layer<I, O>(
&self, id: TypeId, mock: Box<FnMut<I, Output=MockResult<I, O>> + 'static>) {
self.layers.borrow_mut().first_mut().expect("Thread mock level missing").add(id, mock);
}

pub unsafe fn call<I, O>(&self, id: TypeId, mut input: I) -> MockResult<I, O> {
// Do not hold RefCell borrow while calling mock, it can try to modify mocks
let layer_count = self.layers.borrow().len();
for layer_idx in (0..layer_count).rev() {
let mock_opt = self.layers.borrow()
.get(layer_idx)
.expect("Mock layer removed while iterating")
.get(id);
if let Some(mock) = mock_opt {
match mock.call(input) {
MockLayerResult::Handled(result) => return result,
MockLayerResult::Unhandled(new_input) => input = new_input,
}
}
}
MockResult::Continue(input)
}
}

//TODO tests
// clear
// clear id
// add and remove layer
// inside mock closure

impl Default for MockStore {
fn default() -> Self {
MockStore {
layers: RefCell::new(vec![MockLayer::default()]),
}
}
}

#[derive(Default)]
pub struct MockLayer {
mocks: HashMap<TypeId, ErasedStoredMock>,
}

impl MockLayer {
fn clear(&mut self) {
self.mocks.clear()
}

fn clear_id(&mut self, id: TypeId) {
self.mocks.remove(&id);
}

pub unsafe fn add<I, O>(&mut self, id: TypeId, mock: Box<FnMut<I, Output=MockResult<I, O>> + 'static>) {
let stored = StoredMock::new(mock).erase();
self.mocks.insert(id, stored);
}

unsafe fn get(&self, id: TypeId) -> Option<ErasedStoredMock> {
self.mocks.get(&id).cloned()
}
}

pub enum MockLayerResult<I, O> {
Handled(MockResult<I, O>),
Unhandled(I),
}

#[derive(Clone)]
struct ErasedStoredMock {
mock: StoredMock<(), ()>,
}

impl ErasedStoredMock {
unsafe fn call<I, O>(self, input: I) -> MockLayerResult<I, O> {
let unerased: StoredMock<I, O> = transmute(self.mock);
unerased.call(input)
}
}

/// Guarantees that while mock is running it's not overwritten, destroyed, or called again
#[derive(Clone)]
struct StoredMock<I, O> {
mock: Rc<RefCell<Box<FnMut<I, Output=MockResult<I, O>>>>>
}

impl<I, O> StoredMock<I, O> {
fn new(mock: Box<FnMut<I, Output=MockResult<I, O>> + 'static>) -> Self {
StoredMock {
mock: Rc::new(RefCell::new(mock))
}
}

fn call(&self, input: I) -> MockLayerResult<I, O> {
match self.mock.try_borrow_mut() {
Ok(mut mock) => MockLayerResult::Handled(mock.call_mut(input)),
Err(_) => MockLayerResult::Unhandled(input),
}
}

fn erase(self) -> ErasedStoredMock {
unsafe {
ErasedStoredMock {
mock: transmute(self),
}
}
}
}
117 changes: 39 additions & 78 deletions src/mocking.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use crate::mock_store::{MockLayer, MockStore};
use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::mem::transmute;
use std::rc::Rc;

/// Trait for setting up mocks
///
Expand Down Expand Up @@ -90,55 +88,20 @@ pub enum MockResult<T, O> {
}

thread_local!{
static MOCK_STORE: RefCell<HashMap<TypeId, Rc<RefCell<Box<FnMut<(), Output=()>>>>>> = RefCell::new(HashMap::new())
static MOCK_STORE: MockStore = MockStore::default()
}

/// Clear all mocks in the ThreadLocal; only necessary if tests share threads
pub fn clear_mocks() {
MOCK_STORE.with(|mock_ref_cell| {
mock_ref_cell.borrow_mut().clear();
});
}

struct ScopedMock<'a> {
phantom: PhantomData<&'a ()>,
id: TypeId,
}

impl<'a> ScopedMock<'a> {
unsafe fn new<T, O, M: Mockable<T, O> + 'a, F: FnMut<T, Output=MockResult<T, O>>>(
mockable: &M,
mock: F,
) -> Self {
mockable.mock_raw(mock);
ScopedMock {
phantom: PhantomData,
id: mockable.get_mock_id(),
}
}
}

impl<'a> Drop for ScopedMock<'a> {
fn drop(&mut self) {
clear_id(self.id);
}
}

fn clear_id(id: TypeId) {
MOCK_STORE.with(|mock_ref_cell| {
mock_ref_cell.borrow_mut().remove(&id);
});
MOCK_STORE.with(|mock_store| mock_store.clear())
}

impl<T, O, F: FnOnce<T, Output=O>> Mockable<T, O> for F {
unsafe fn mock_raw<M: FnMut<T, Output=MockResult<T, O>>>(&self, mock: M) {
let id = self.get_mock_id();
MOCK_STORE.with(|mock_ref_cell| {
let real = Rc::new(RefCell::new(Box::new(mock) as Box<FnMut<_, Output=_>>));
let stored = transmute(real);
mock_ref_cell.borrow_mut()
.insert(id, stored);
})
let boxed = Box::new(mock) as Box::<FnMut<_, Output = _>>;
let static_boxed: Box<FnMut<T, Output = MockResult<T, O>> + 'static> = transmute(boxed);
MOCK_STORE.with(|mock_store| mock_store.add_to_thread_layer(id, static_boxed))
}

fn mock_safe<M: FnMut<T, Output=MockResult<T, O>> + 'static>(&self, mock: M) {
Expand All @@ -149,26 +112,13 @@ impl<T, O, F: FnOnce<T, Output=O>> Mockable<T, O> for F {

fn clear_mock(&self) {
let id = unsafe { self.get_mock_id() };
clear_id(id);
MOCK_STORE.with(|mock_store| mock_store.clear_id(id))
}

fn call_mock(&self, input: T) -> MockResult<T, O> {
unsafe {
let id = self.get_mock_id();
let rc_opt = MOCK_STORE.with(|mock_ref_cell|
mock_ref_cell.borrow()
.get(&id)
.cloned()
);
let stored_opt = rc_opt.as_ref()
.and_then(|rc| rc.try_borrow_mut().ok());
match stored_opt {
Some(mut stored) => {
let real: &mut Box<FnMut<_, Output=_>> = transmute(&mut*stored);
real.call_mut(input)
}
None => MockResult::Continue(input),
}
MOCK_STORE.with(|mock_store| mock_store.call(id, input))
}
}

Expand Down Expand Up @@ -223,7 +173,8 @@ impl<T, O, F: FnOnce<T, Output=O>> Mockable<T, O> for F {
/// ```
#[derive(Default)]
pub struct MockContext<'a> {
planned_mocks: HashMap<TypeId, Box<FnOnce() -> ScopedMock<'a> + 'a>>,
mock_layer: MockLayer,
phantom_lifetime: PhantomData<&'a ()>,
}

impl<'a> MockContext<'a> {
Expand All @@ -236,20 +187,23 @@ impl<'a> MockContext<'a> {
///
/// This function doesn't actually mock the function. It registers it as a
/// function that will be mocked when [`run`](#method.run) is called.
pub fn mock_safe<
Args,
Output,
M: Mockable<Args, Output> + 'a,
F: FnMut<Args, Output = MockResult<Args, Output>> + 'a,
>(
mut self,
mock: M,
body: F,
) -> Self {
self.planned_mocks.insert(
unsafe { mock.get_mock_id() },
Box::new(move || unsafe { ScopedMock::new(&mock, body) }),
);
pub fn mock_safe<I, O, F, M>(self, mockable: F, mock: M) -> Self
where F: Mockable<I, O>, M: FnMut<I, Output = MockResult<I, O>> + 'a {
unsafe {
self.mock_raw(mockable, mock)
}
}

/// Set up a function to be mocked.
///
/// This is an unsafe version of [`mock_safe`](#method.mock_safe),
/// without lifetime constraint on mock
pub unsafe fn mock_raw<I, O, F, M>(mut self, mockable: F, mock: M) -> Self
where F: Mockable<I, O>, M: FnMut<I, Output = MockResult<I, O>> {
let mock_box = Box::new(mock) as Box<FnMut<_, Output = _>>;
let mock_box_static: Box<FnMut<I, Output = MockResult<I, O>> + 'static>
= std::mem::transmute(mock_box);
self.mock_layer.add(mockable.get_mock_id(), mock_box_static);
self
}

Expand All @@ -262,11 +216,18 @@ impl<'a> MockContext<'a> {
///
/// Register a function for mocking with [`mock_safe`](#method.mock_safe).
pub fn run<T, F: FnOnce() -> T>(self, f: F) -> T {
let _scoped_mocks = self
.planned_mocks
.into_iter()
.map(|entry| entry.1())
.collect::<Vec<_>>();
MOCK_STORE.with(|mock_store| unsafe { mock_store.add_layer(self.mock_layer) });
let _mock_level_guard = MockLayerGuard;
f()
}
}

struct MockLayerGuard;

impl<'a> Drop for MockLayerGuard {
fn drop(&mut self) {
MOCK_STORE.with(|mock_store| unsafe {
mock_store.remove_layer()
});
}
}
Loading

0 comments on commit 00e4fbc

Please sign in to comment.