Skip to content

Commit

Permalink
Added JQ syntax for replacements + min/max score. (#133)
Browse files Browse the repository at this point in the history
* added score, bugfix name

* allow jq in spans, min score

* e2e tests
  • Loading branch information
soldni authored Mar 6, 2024
1 parent cd7d983 commit a75bf09
Show file tree
Hide file tree
Showing 6 changed files with 456 additions and 112 deletions.
30 changes: 26 additions & 4 deletions python/dolma/cli/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ class FilterConfig:
@dataclass
class SpanReplacementConfig:
span: str = field(help="JSONPath expression for the span to replace")
min_score: float = field(default=0.5, help="Minimum score for the span to be replaced")
min_score: Optional[float] = field(
default=None,
help="Minimum score for the span to be replaced. Either min_score or max_score must be specified.",
)
max_score: Optional[float] = field(
default=None,
help="Maximum score for the span to be replaced. Either min_score or max_score must be specified.",
)
replacement: str = field(default="", help="Replacement for the span")
syntax: str = field(
default="jsonpath",
Expand Down Expand Up @@ -97,15 +104,30 @@ def run(cls, parsed_config: MixerConfig):
}

for span_replacement in stream_config.span_replacement:
if span_replacement.syntax not in ["jsonpath"]:
raise DolmaConfigError("Invalid span_replacement syntax; must be 'jsonpath'")
if span_replacement.syntax not in ["jsonpath", "jq"]:
raise DolmaConfigError("Invalid span_replacement syntax; must be 'jsonpath' or 'jq'")

if span_replacement.min_score is None and span_replacement.max_score is None:
raise DolmaConfigError(
"Either min_score or max_score must be specified for span_replacement"
)

# add min_score and max_score to the config if they are specified
min_score_config = (
{"min_score": span_replacement.min_score} if span_replacement.min_score is not None else {}
)
max_score_config = (
{"max_score": span_replacement.max_score} if span_replacement.max_score is not None else {}
)

# TODO: note that we are not using the syntax here yet; adding it later
stream_config_dict.setdefault("span_replacement", []).append(
{
"span": str(span_replacement.span),
"min_score": float(span_replacement.min_score),
"replacement": str(span_replacement.replacement),
"syntax": span_replacement.syntax,
**min_score_config,
**max_score_config,
}
)

Expand Down
257 changes: 249 additions & 8 deletions src/filters.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,252 @@
use std::io;

use crate::shard::shard_config::FilterConfig;
use crate::shard::shard_config::{FilterConfig, SpanReplacementConfig};
use jaq_interpret::{Ctx, Filter, FilterT, ParseCtx, RcIter, Val};
use jaq_std;
use jsonpath_rust::JsonPathFinder;
use serde_json::Value;

pub struct JqSelector {
pub selector: Filter,
}

impl JqSelector {
pub fn new(selector_string: &str) -> Result<JqSelector, io::Error> {
let mut defs = ParseCtx::new(Vec::new());
defs.insert_natives(jaq_core::core());
defs.insert_defs(jaq_std::std());
assert!(defs.errs.is_empty());

let (selector, errs) = jaq_parse::parse(selector_string, jaq_parse::main());
if !errs.is_empty() {
return Err(io::Error::new(
io::ErrorKind::Other,
format!(
"Error parsing '{:?}' into filter: {:?}",
selector_string, errs
),
));
}
match selector {
Some(selector) => {
let selector: jaq_interpret::Filter = defs.compile(selector);
if !defs.errs.is_empty() {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("Error compiling '{:?}' into filter.", selector_string),
));
}

Ok(JqSelector { selector: selector })
}
None => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("Parsing '{:?}' resulted in no filter", selector_string),
));
}
}
}

// select returns array of results if the filter matches multiple elements,
// or a single result if the filter matches a single element.
// in case of no match, it returns null
pub fn select(&self, json: &Value) -> Result<Value, io::Error> {
let inputs: RcIter<std::iter::Empty<_>> = RcIter::new(core::iter::empty());
let out: Vec<Result<jaq_interpret::Val, jaq_interpret::Error>> = self
.selector
.run((Ctx::new(Vec::new(), &inputs), Val::from(json.clone())))
.collect();
if out.is_empty() {
return Ok(Value::Null);
}
let mut result = Vec::new();
for resp in out {
match resp {
Ok(val) => result.push(val),
Err(e) => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("Error evaluating filter: {:?}", e),
))
}
}
}

match result.len() {
0 => Ok(Value::Null),
1 => Ok(Value::from(result[0].clone())),
_ => Ok(Value::from(result)),
}
}
}

/// Selector that extracts a value from a JSON document using a JSONPath expression.
pub struct JsonPathSelector {
    /// Raw JSONPath expression. It is stored as given and only parsed when
    /// [`JsonPathSelector::select`] runs.
    pub path: String,
}

impl JsonPathSelector {
    /// Stores `path` for later evaluation.
    ///
    /// The expression is not parsed here, so this constructor always returns
    /// `Ok`; a malformed path surfaces as an error from `select`.
    pub fn new(path: &str) -> Result<JsonPathSelector, io::Error> {
        let path = path.to_string();
        Ok(JsonPathSelector { path })
    }

    /// Evaluates the stored JSONPath against `json`.
    ///
    /// Returns `Value::Null` when the finder yields null or an empty array,
    /// the single element when it yields a one-element array, and the whole
    /// array when it yields several elements.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` of kind `Other` when the finder cannot be built
    /// from the stored path, or when it yields something that is neither an
    /// array nor null.
    pub fn select(&self, json: &Value) -> Result<Value, io::Error> {
        // The finder is created over a placeholder document; the real one is
        // attached right after.
        let mut finder = JsonPathFinder::from_str("{}", &self.path).map_err(|e| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Error evaluating filter: {:?}", e),
            )
        })?;
        finder.set_json(Box::new(json.clone()));

        match finder.find() {
            Value::Null => Ok(Value::Null),
            Value::Array(mut matches) => {
                if matches.len() > 1 {
                    Ok(Value::Array(matches))
                } else {
                    // Zero matches -> null; exactly one match -> that match.
                    Ok(matches.pop().unwrap_or(Value::Null))
                }
            }
            _ => Err(io::Error::new(
                io::ErrorKind::Other,
                format!("Error evaluating filter: {:?}", self.path),
            )),
        }
    }
}

/// A span selector in one of the supported syntaxes.
pub enum Selector {
    /// Selector written as a jq expression.
    JqSelector(JqSelector),
    /// Selector written as a JSONPath expression.
    JsonPathSelector(JsonPathSelector),
}

impl Selector {
    /// Builds the selector variant named by `selector_config.syntax`, using
    /// `selector_config.span` as the expression.
    ///
    /// `"jq"` yields a [`JqSelector`]; `"jsonpath"` or a missing syntax yields
    /// a [`JsonPathSelector`].
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` of kind `Other` for any other syntax value, and
    /// propagates any error from the chosen selector's constructor.
    pub fn new(selector_config: &SpanReplacementConfig) -> Result<Selector, io::Error> {
        let span: &str = &selector_config.span;
        match selector_config.syntax.as_deref() {
            Some("jq") => JqSelector::new(span).map(Selector::JqSelector),
            None | Some("jsonpath") => {
                JsonPathSelector::new(span).map(Selector::JsonPathSelector)
            }
            Some(_) => Err(io::Error::new(
                io::ErrorKind::Other,
                format!("Unknown selector syntax: {:?}", selector_config.syntax),
            )),
        }
    }

    /// Applies the wrapped selector to `json` and returns whatever it selects.
    pub fn select(&self, json: &Value) -> Result<Value, io::Error> {
        match self {
            Selector::JsonPathSelector(inner) => inner.select(json),
            Selector::JqSelector(inner) => inner.select(json),
        }
    }
}

#[cfg(test)]
pub mod selector_tests {
    use super::*;
    use serde_json::json;

    /// Builds a jq selector from `expr` and returns what it selects from `doc`.
    fn run_jq(expr: &str, doc: &serde_json::Value) -> serde_json::Value {
        JqSelector::new(expr).unwrap().select(doc).unwrap()
    }

    /// Builds a JSONPath selector from `path` and returns what it selects from `doc`.
    fn run_jsonpath(path: &str, doc: &serde_json::Value) -> serde_json::Value {
        JsonPathSelector::new(path).unwrap().select(doc).unwrap()
    }

    /// Asserts that the jq expression and the JSONPath expression both select
    /// `expected` from `doc`.
    fn assert_both_select(
        jq_expr: &str,
        jsonpath_expr: &str,
        doc: &serde_json::Value,
        expected: &serde_json::Value,
    ) {
        assert_eq!(&run_jq(jq_expr, doc), expected);
        assert_eq!(&run_jsonpath(jsonpath_expr, doc), expected);
    }

    // A scalar field comes back as the bare scalar in both syntaxes.
    #[test]
    fn test_select() {
        let doc = json!({
            "attributes": {
                "foo": "bar",
                "baz": "qux"
            }
        });
        assert_both_select(".attributes.foo", "$.attributes.foo", &doc, &json!("bar"));
    }

    // An array field comes back as that array.
    #[test]
    fn test_select_array() {
        let doc = json!({
            "attributes": {
                "foo": [1, 2, 3],
                "baz": "qux"
            }
        });
        assert_both_select(
            ".attributes.foo",
            "$.attributes.foo",
            &doc,
            &json!([1, 2, 3]),
        );
    }

    // Selecting an object with jq returns the whole object.
    #[test]
    fn test_select_object() {
        let doc = json!({
            "attributes": {
                "foo": "bar",
                "baz": "qux"
            }
        });
        assert_eq!(
            run_jq(".attributes", &doc),
            json!({"foo": "bar", "baz": "qux"})
        );
    }

    // A missing key selects null in both syntaxes.
    #[test]
    fn test_select_null() {
        let doc = json!({
            "attributes": {
                "baz": "qux"
            }
        });
        assert_both_select(".attributes.foo", "$.attributes.foo", &doc, &json!(null));
    }

    // A missing intermediate key also selects null.
    #[test]
    fn test_nested_select_null() {
        let doc = json!({
            "attributes": {
                "not_foo": {
                    "baz": "qux"
                }
            }
        });
        assert_both_select(
            ".attributes?.foo?.baz?",
            "$.attributes.foo.baz",
            &doc,
            &json!(null),
        );
    }

    // Piping an array of strings through `add` concatenates them.
    #[test]
    fn test_select_error() {
        let doc = json!({
            "attributes": {
                "foo": ["water", " & ", "bread"],
            }
        });
        assert_eq!(
            run_jq(".attributes.foo | add", &doc),
            json!("water & bread")
        );
    }
}

pub struct JqDocFilter {
pub include: Vec<Filter>,
pub exclude: Vec<Filter>,
Expand Down Expand Up @@ -157,7 +398,7 @@ impl DocFilter {
pub fn new(filter_config: Option<&FilterConfig>) -> Result<DocFilter, io::Error> {
match filter_config {
Some(filter_config) => match filter_config.syntax.as_deref() {
Some("jaq") => Ok(DocFilter::JqDocFilter(JqDocFilter::new(filter_config)?)),
Some("jq") => Ok(DocFilter::JqDocFilter(JqDocFilter::new(filter_config)?)),
Some("jsonpath") | None => Ok(DocFilter::JsonPathFilter(JsonPathFilter::new(
filter_config,
)?)),
Expand All @@ -179,7 +420,7 @@ impl DocFilter {
}

#[cfg(test)]
mod tests {
mod filter_tests {
use super::*;
use serde_json::json;

Expand All @@ -188,7 +429,7 @@ mod tests {
let filter_config = FilterConfig {
include: vec![".attributes.foo".to_string()],
exclude: vec![r#".attributes.baz == "quac""#.to_string()],
syntax: Some("jaq".to_string()),
syntax: Some("jq".to_string()),
};
let filters = DocFilter::new(Some(&filter_config)).unwrap();
let doc = json!({
Expand All @@ -205,7 +446,7 @@ mod tests {
let filter_config = FilterConfig {
include: vec![".attributes.foo".to_string()],
exclude: vec![r#".attributes.baz == "qux""#.to_string()],
syntax: Some("jaq".to_string()),
syntax: Some("jq".to_string()),
};
let filters = DocFilter::new(Some(&filter_config)).unwrap();
let doc = json!({
Expand All @@ -222,7 +463,7 @@ mod tests {
let filter_config = FilterConfig {
include: vec![".attributes.foo | length >= 3".to_string()],
exclude: vec![],
syntax: Some("jaq".to_string()),
syntax: Some("jq".to_string()),
};
let filters = DocFilter::new(Some(&filter_config)).unwrap();
let doc = json!({
Expand Down Expand Up @@ -293,7 +534,7 @@ mod tests {
let filter_config = FilterConfig {
include: vec![".attributes.foo | add >= 6".to_string()],
exclude: vec![],
syntax: Some("jaq".to_string()),
syntax: Some("jq".to_string()),
};
let filters = DocFilter::new(Some(&filter_config)).unwrap();
let doc = json!({
Expand All @@ -318,7 +559,7 @@ mod tests {
let filter_config = FilterConfig {
include: vec![".x | sum".to_string()],
exclude: vec![],
syntax: Some("jaq".to_string()),
syntax: Some("jq".to_string()),
};

let result = DocFilter::new(Some(&filter_config));
Expand Down
Loading

0 comments on commit a75bf09

Please sign in to comment.