Skip to content

Commit

Permalink
Add XorShiftRng serde
Browse files Browse the repository at this point in the history
  • Loading branch information
UserAB1236872 committed Oct 27, 2017
1 parent cbbee62 commit 24842c7
Showing 1 changed file with 183 additions and 0 deletions.
183 changes: 183 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,14 @@ use std::mem;
use std::io;
use std::rc::Rc;
use std::num::Wrapping as w;
#[cfg(feature="serde-1")]
use std::fmt;

#[cfg(feature = "serde-1")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature="serde-1")]
use serde::de::Visitor;


pub use os::OsRng;

Expand Down Expand Up @@ -804,6 +812,154 @@ impl Rand for XorShiftRng {
}
}

#[cfg(feature = "serde-1")]
impl Serialize for XorShiftRng {
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::SerializeStruct;

let mut state = ser.serialize_struct("XorShiftRng",6)?;

let w(x) = self.x;
state.serialize_field("x", &x)?;

let w(y) = self.y;
state.serialize_field("y", &y)?;

let w(z) = self.z;
state.serialize_field("z", &z)?;

let w(w_field) = self.w;
state.serialize_field("w", &w_field)?;

state.end()
}
}

#[cfg(feature="serde-1")]
impl<'de> Deserialize<'de> for XorShiftRng {
fn deserialize<D>(de: D) -> Result<XorShiftRng, D::Error>
where D: Deserializer<'de> {
use serde::de::{SeqAccess,MapAccess};
use serde::de;

enum Field { X, Y, Z, W };

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where D: Deserializer<'de> {
struct XorFieldVisitor;
impl<'de> Visitor<'de> for XorFieldVisitor {
type Value = Field;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("`x`, `y`, `z`, or `w`")
}

fn visit_str<E>(self, value: &str) -> Result<Field,E>
where E: de::Error {
match value {
"x" => Ok(Field::X),
"y" => Ok(Field::Y),
"z" => Ok(Field::Z),
"w" => Ok(Field::W),
_ => Err(de::Error::unknown_field(value, FIELDS))
}
}
}
deserializer.deserialize_identifier(XorFieldVisitor)
}
}

struct XorVisitor;

const FIELDS: &[&'static str] = &["x", "y", "z", "w"];

impl<'de> Visitor<'de> for XorVisitor {
type Value = XorShiftRng;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct XorShiftRng")
}

fn visit_seq<V>(self, mut seq: V) -> Result<XorShiftRng, V::Error>
where V: SeqAccess<'de> {
let x: u32 = seq.next_element()?
.ok_or_else(|| de::Error::invalid_length(0,&self))?;

let y: u32 = seq.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;

let z: u32 = seq.next_element()?
.ok_or_else(|| de::Error::invalid_length(2, &self))?;

let w_field: u32 = seq.next_element()?
.ok_or_else(|| de::Error::invalid_length(3, &self))?;


let (x,y,z,w_field) = (w(x), w(y), w(z), w(w_field));

Ok(XorShiftRng {
x: x,y: y,z: z,w: w_field
})
}

fn visit_map<V>(self, mut map: V) -> Result<XorShiftRng, V::Error>
where V: MapAccess<'de>
{
let mut x = None;
let mut y = None;
let mut z = None;
let mut w_field = None;

while let Some(key) = map.next_key()? {
match key {
Field::X => {
if x.is_some() {
return Err(de::Error::duplicate_field("x"));
}
x = Some(map.next_value()?);
}
Field::Y => {
if y.is_some() {
return Err(de::Error::duplicate_field("y"));
}
y = Some(map.next_value()?);
}
Field::Z => {
if z.is_some() {
return Err(de::Error::duplicate_field("z"));
}
z = Some(map.next_value()?);
}
Field::W => {
if w_field.is_some() {
return Err(de::Error::duplicate_field("w"));
}
w_field = Some(map.next_value()?);
}
}
}

let x = x.ok_or_else(|| de::Error::missing_field("x"))?;
let y = y.ok_or_else(|| de::Error::missing_field("y"))?;
let z = z.ok_or_else(|| de::Error::missing_field("z"))?;
let w_field = w_field.ok_or_else(|| de::Error::missing_field("w"))?;

let (x,y,z,w_field) = (w(x), w(y), w(z), w(w_field));

Ok(XorShiftRng {
x: x,y: y,z: z,w: w_field
})
}
}

de.deserialize_struct("IsaacRng", FIELDS, XorVisitor)
}
}

/// A wrapper for generating floating point numbers uniformly in the
/// open interval `(0,1)` (not including either endpoint).
///
Expand Down Expand Up @@ -1311,4 +1467,31 @@ mod test {
assert_eq!(rng.next_u64(), deserialized.next_u64());
}
}

#[test]
fn test_xor_serde() {
use super::XorShiftRng;
use bincode;
use std::io::{BufWriter, BufReader};

let seed: [u32; 4] = thread_rng().gen();
let mut rng: XorShiftRng = SeedableRng::from_seed(seed);

let buf: Vec<u8> = Vec::new();
let mut buf = BufWriter::new(buf);
bincode::serialize_into(&mut buf, &rng, bincode::Infinite).expect("Could not serialize");

let buf = buf.into_inner().unwrap();
let mut read = BufReader::new(&buf[..]);
let mut deserialized: XorShiftRng = bincode::deserialize_from(&mut read, bincode::Infinite).expect("Could not deserialize");

assert_eq!(rng.x, deserialized.x);
assert_eq!(rng.y, deserialized.y);
assert_eq!(rng.z, deserialized.z);
assert_eq!(rng.w, deserialized.w);

for _ in 0..16 {
assert_eq!(rng.next_u64(), deserialized.next_u64());
}
}
}

0 comments on commit 24842c7

Please sign in to comment.