Skip to content

Commit

Permalink
feat: shorthand for retrieving custom constants from Const, Value (
Browse files Browse the repository at this point in the history
…#679)

Closes #654
  • Loading branch information
ss2165 authored Nov 13, 2023
1 parent 78faf6d commit 144e91f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ impl Const {
.unzip();
Self::new(Value::tuple(values), Type::new_tuple(types)).unwrap()
}

/// For a Const holding a CustomConst, extract the CustomConst by downcasting.
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
self.value().get_custom_value()
}
}

impl OpName for Const {
Expand Down Expand Up @@ -123,7 +128,7 @@ mod test {
prelude::{ConstUsize, USIZE_T},
ExtensionId, ExtensionSet,
},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE},
type_row,
types::test::test_registry,
types::type_param::TypeArg,
Expand Down Expand Up @@ -197,11 +202,23 @@ mod test {
tuple_ty.check_type(&tuple_val2),
Err(ConstTypeError::ValueCheckFail(ty, tv2)) => ty == tuple_ty && tv2 == tuple_val2
);
let tuple_val3 = Value::tuple([int_value, serialized_float(3.3), serialized_float(2.0)]);
let tuple_val3 = Value::tuple([
int_value.clone(),
serialized_float(3.3),
serialized_float(2.0),
]);
assert_eq!(
tuple_ty.check_type(&tuple_val3),
Err(ConstTypeError::TupleWrongLength)
);

let op = Const::new(int_value, USIZE_T).unwrap();

assert_eq!(op.get_custom_value(), Some(&ConstUsize::new(257)));
let try_float: Option<&ConstF64> = op.get_custom_value();
assert!(try_float.is_none());
let try_usize: Option<&ConstUsize> = tuple_val.get_custom_value();
assert!(try_usize.is_none());
}

#[test]
Expand Down
12 changes: 12 additions & 0 deletions src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ impl Value {
val: PrimValue::Extension { c: (Box::new(c),) },
}
}

/// For a Const holding a CustomConst, extract the CustomConst by downcasting.
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
if let Value::Prim {
val: PrimValue::Extension { c: (custom,) },
} = self
{
custom.downcast_ref()
} else {
None
}
}
}

impl<T: CustomConst> From<T> for Value {
Expand Down

0 comments on commit 144e91f

Please sign in to comment.