-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from PySport/pattern-matching
Pattern matching
- Loading branch information
Showing
16 changed files
with
1,554 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# kloppy-query | ||
|
||
Video analysts spend a lot of time searching for interesting moments in video. Often a certain type of moment can be described by a pattern: pass, pass, shot, etc. In that case, can we automate the search? | ||
|
||
We might be able to do so. The kloppy library now provides a search mechanism based on regular expressions to search for patterns within event data. | ||
|
||
To make this even simpler to use, kloppy comes with `kloppy-query`. This command-line tool does all the heavy lifting for you and gives you a nice XML file, ready for use in your favorite video analysis software. | ||
|
||
## Usage | ||
|
||
```shell script | ||
# grab some data from statsbomb open data project | ||
wget https://github.com/statsbomb/open-data/blob/master/data/events/15946.json?raw=true -O events.json | ||
wget https://raw.githubusercontent.com/statsbomb/open-data/master/data/lineups/15946.json -O lineup.json | ||
|
||
# run the query | ||
kloppy-query --input-statsbomb=events.json,lineup.json --query-file=ball_recovery.py --output-xml=ball_recovery.xml | ||
|
||
# check output | ||
cat ball_recovery.xml | ||
``` | ||
|
||
|
||
```xml | ||
<?xml version='1.0' encoding='utf-8'?> | ||
<file> | ||
<ALL_INSTANCES> | ||
<instance> | ||
<ID>0</ID> | ||
<code>away</code> | ||
<start>0.0</start> | ||
<end>16.15</end> | ||
</instance> | ||
<instance> | ||
<ID>1</ID> | ||
<code>home</code> | ||
<start>4.15</start> | ||
<end>29.687</end> | ||
</instance> | ||
<instance> | ||
<ID>2</ID> | ||
<code>away</code> | ||
<start>17.687</start> | ||
<end>71.228</end> | ||
</instance> | ||
<instance> | ||
<ID>3</ID> | ||
<code>home success</code> | ||
<start>59.227999999999994</start> | ||
<end>85.809</end> | ||
</instance> | ||
</ALL_INSTANCES> | ||
</file> | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from kloppy import event_pattern_matching as pm | ||
|
||
# This file can be consumed by kloppy-query command line like this: | ||
|
||
|
||
# kloppy-query --input-statsbomb=events.json,lineup.json --query-file=ball_recovery.py --output-xml=test.xml | ||
# The first pass is captured so the later steps can refer back to its team
# and timestamp.
_opening_pass = pm.match_pass(capture="last_pass_of_team_a")

# Passes by the team that did NOT play the captured pass.
# NOTE(review): `* slice(1, None)` presumably means "one or more
# repetitions" -- confirm against the regexp implementation.
_opponent_passes = pm.match_pass(
    team=pm.not_same_as("last_pass_of_team_a.team")
) * slice(1, None)

# A successful pass by the original team whose timestamp is less than 10
# after the captured pass. The lambda's second parameter name is the capture
# name joined with the attribute name, which is how `pm.function` passes
# captured values in.
_regained_pass = pm.match_pass(
    success=True,
    team=pm.same_as("last_pass_of_team_a.team"),
    timestamp=pm.function(
        lambda timestamp, last_pass_of_team_a_timestamp:
        timestamp - last_pass_of_team_a_timestamp < 10
    ),
)

# The follow-up is either another successful pass or a shot by the same team.
_follow_up = (
    pm.match_pass(
        success=True,
        team=pm.same_as("last_pass_of_team_a.team"),
    )
    | pm.match_shot(team=pm.same_as("last_pass_of_team_a.team"))
)

# NOTE(review): `* slice(0, 1)` presumably makes the "success" group
# optional -- confirm against the regexp implementation.
_optional_success = pm.group(
    _regained_pass + _follow_up,
    capture="success",
) * slice(0, 1)

query = pm.Query(
    event_types=["pass", "shot"],
    pattern=_opening_pass + _opponent_passes + _optional_success,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .infra.serializers import * | ||
from .helpers import * | ||
from .infra import datasets | ||
from .domain.services.matchers.pattern import event as event_pattern_matching |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import argparse | ||
import sys | ||
import logging | ||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
from xml.etree import ElementTree as ET | ||
|
||
from kloppy import load_statsbomb_event_data, event_pattern_matching as pm | ||
from kloppy.infra.utils import performance_logging | ||
|
||
sys.path.append(".") | ||
|
||
|
||
@dataclass
class VideoFragment:
    """A labelled time window, written as one ``<instance>`` by write_to_xml."""

    # Text of the <ID> element.
    id_: str
    # Text of the <code> element.
    label: str
    # Window start; write_to_xml clamps negative values to 0.0.
    start: float
    # Window end; written to <end> unchanged.
    end: float
|
||
|
||
def write_to_xml(video_fragments: List[VideoFragment], filename):
    """Serialise video fragments to an XML file.

    The document has a ``<file>`` root holding one ``<ALL_INSTANCES>``
    element, which holds one ``<instance>`` per fragment with ``<ID>``,
    ``<code>``, ``<start>`` and ``<end>`` children.

    :param video_fragments: fragments to write, in output order.
    :param filename: destination handed to ``ElementTree.write``.
    """
    root = ET.Element("file")
    container = ET.SubElement(root, "ALL_INSTANCES")

    for fragment in video_fragments:
        node = ET.SubElement(container, "instance")
        children = (
            ("ID", fragment.id_),
            ("code", fragment.label),
            # A start before zero is clamped to the beginning of the video.
            ("start", str(max(0.0, fragment.start))),
            ("end", str(fragment.end)),
        )
        for tag, text in children:
            ET.SubElement(node, tag).text = text

    ET.ElementTree(root).write(
        filename,
        xml_declaration=True,
        encoding='utf-8',
        method="xml",
    )
|
||
|
||
def load_query(query_file: str) -> "pm.Query":
    """Execute a Python query file and return its module-level ``query`` object.

    The file runs in a single fresh namespace, exactly like a module body, so
    functions and lambdas defined in the file can see the file's own imports
    and helpers. (Running it with separate globals and locals makes those
    lookups fail with ``NameError``.)

    :param query_file: path of the Python file that defines ``query``.
    :return: whatever object the file bound to the name ``query``.
    :raises Exception: when the file does not define ``query``.
    """
    with open(query_file, "rb") as fp:
        source = fp.read()

    # NOTE(security): this executes arbitrary Python from ``query_file``.
    # Only load query files you trust.
    namespace = {}
    # Compiling with the real filename makes tracebacks point at the query file.
    exec(compile(source, query_file, "exec"), namespace)

    if 'query' not in namespace:
        raise Exception("File does not contain query")
    return namespace['query']
|
||
|
||
def _parse_bool(value):
    """Convert a command-line string such as ``true``/``no``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


def run_query(argv=sys.argv[1:]):
    """Command-line entry point: run a pattern query and write the matches to XML.

    Steps: parse the arguments, load the query file, load the StatsBomb
    dataset restricted to the query's event types, search for the pattern,
    turn every match into a ``VideoFragment`` and write them to the output
    file.

    :param argv: argument list to parse (without the program name).
    :raises Exception: when no dataset input was given, or when the query
        file does not define ``query``.
    """
    parser = argparse.ArgumentParser(description="Run query on event data")
    parser.add_argument('--input-statsbomb', help="StatsBomb event input files (events.json,lineup.json)")
    parser.add_argument('--output-xml', help="Output file", required=True)
    # Without an explicit type every supplied value is a non-empty string and
    # therefore truthy, so "--with-success=false" could never disable it.
    parser.add_argument('--with-success', default=True, type=_parse_bool,
                        help="Input existence of success capture in output")
    # The times take part in timestamp arithmetic below, so they must be
    # numbers and not the raw argument strings.
    parser.add_argument('--prepend-time', default=7, type=float, help="Seconds to prepend to match")
    parser.add_argument('--append-time', default=5, type=float, help="Seconds to append to match")
    parser.add_argument('--query-file', help="File containing the query", required=True)

    logger = logging.getLogger("run_query")
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    opts = parser.parse_args(argv)

    query = load_query(opts.query_file)

    dataset = None
    if opts.input_statsbomb:
        with performance_logging("load dataset", logger=logger):
            events_filename, lineup_filename = opts.input_statsbomb.split(",")
            dataset = load_statsbomb_event_data(
                events_filename.strip(),
                lineup_filename.strip(),
                options={
                    "event_types": query.event_types
                }
            )

    # Test identity rather than truthiness, so a loaded dataset is never
    # mistaken for a missing one.
    if dataset is None:
        raise Exception("You have to specify a dataset.")

    with performance_logging("searching", logger=logger):
        matches = pm.search(dataset, query.pattern)

    video_fragments = []
    for i, match in enumerate(matches):
        first_event = match.events[0]
        last_event = match.events[-1]

        label = str(first_event.team)
        if opts.with_success and 'success' in match.captures:
            label += " success"

        # Event timestamps are offset by the start timestamp of their period
        # before the padding is applied.
        start_timestamp = (
            first_event.timestamp +
            first_event.period.start_timestamp -
            opts.prepend_time
        )
        end_timestamp = (
            last_event.timestamp +
            last_event.period.start_timestamp +
            opts.append_time
        )

        video_fragments.append(
            VideoFragment(
                id_=str(i),
                start=start_timestamp,
                end=end_timestamp,
                label=label
            )
        )

    if opts.output_xml:
        write_to_xml(video_fragments, opts.output_xml)
        logger.info(f"Wrote {len(video_fragments)} video fragments to file")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
from dataclasses import dataclass | ||
from functools import partial | ||
from typing import Callable, Tuple, Dict, List, Iterator | ||
|
||
from kloppy.domain import ( | ||
EventDataset, | ||
PassEvent, ShotEvent, CarryEvent, TakeOnEvent, Event | ||
) | ||
from .regexp import * | ||
from .regexp import _make_match, _TrailItem | ||
|
||
|
||
class WithCaptureMatcher(Matcher):
    """Matcher that delegates the decision to a predicate which also sees the captures.

    The predicate is called with the token under test and a dict mapping each
    capture name to the trail of that capture's first entry.
    """

    def __init__(self, matcher: Callable[[Tok, Dict[str, List[Tok]]], bool]):
        self.matcher = matcher

    def match(self, token: Tok, trail: Tuple[_TrailItem[Out], ...]) -> Iterator[Out]:
        """Yield ``token`` when the predicate accepts it, otherwise yield nothing."""
        children = _make_match(trail).children

        captured = {}
        for name, capture in children.items():
            captured[name] = capture[0].trail

        if not self.matcher(token, captured):
            return
        yield token
|
||
|
||
def match_generic(event_cls, capture=None, **kwargs):
    """Build a pattern node that matches one event of ``event_cls``.

    Each keyword argument constrains one attribute of the event:

    * a callable is called as ``fn(attr_name, actual_value, captures)`` and
      must return a truthy value (see ``same_as``, ``not_same_as`` and
      ``function``);
    * ``success=<bool>`` compares the flag against whether the event has a
      result whose ``is_success`` is truthy;
    * any other value must equal the event's attribute of that name.

    :param event_cls: event class the event must be an instance of.
    :param capture: optional name under which the matched event is captured.
    :param kwargs: attribute constraints as described above.
    :return: the matcher node, wrapped in a named capture when ``capture``
        is given.
    """
    def _matcher_fn(event: Event, captures: Dict[str, List[Event]]) -> bool:
        if not isinstance(event, event_cls):
            return False

        # TODO: v[0] points to first record
        first_captured = {
            name: records[0] for name, records in captures.items()
        }
        for attr_name, expected in kwargs.items():
            if callable(expected):
                actual = getattr(event, attr_name)
                satisfied = expected(attr_name, actual, first_captured)
            elif attr_name == "success":
                # Compare against the requested flag. Previously the value
                # was ignored, so success=False still demanded a successful
                # event.
                is_success = bool(event.result and event.result.is_success)
                satisfied = is_success == bool(expected)
            else:
                satisfied = getattr(event, attr_name) == expected
            if not satisfied:
                return False
        return True

    _matcher = Final(WithCaptureMatcher(matcher=_matcher_fn))

    if capture:
        return _matcher[capture]
    return _matcher
|
||
|
||
# Ready-made matchers: match_generic with the event class already bound.
match_carry = partial(match_generic, CarryEvent)
match_pass = partial(match_generic, PassEvent)
match_shot = partial(match_generic, ShotEvent)
match_take_on = partial(match_generic, TakeOnEvent)
|
||
|
||
def same_as(capture: str):
    """Return a constraint requiring a value to equal a captured attribute.

    :param capture: reference of the form ``"<capture name>.<attribute>"``;
        it must contain exactly one dot.
    :return: ``fn(attr_name, value, captures)`` which is true when ``value``
        equals that attribute of the named capture.
    """
    capture_name, attribute_name = capture.split(".")

    def _equals_captured(attr_name, value, captures):
        captured = captures[capture_name]
        expected = getattr(captured, attribute_name)
        return value == expected

    return _equals_captured
|
||
|
||
def not_same_as(capture: str):
    """Return a constraint requiring a value to differ from a captured attribute.

    :param capture: reference of the form ``"<capture name>.<attribute>"``;
        it must contain exactly one dot.
    :return: ``fn(attr_name, value, captures)`` which is true when ``value``
        differs from that attribute of the named capture.
    """
    capture_name, capture_attribute_name = capture.split(".")

    def _differs_from_captured(attr_name, value, captures):
        captured = captures[capture_name]
        other = getattr(captured, capture_attribute_name)
        return value != other

    return _differs_from_captured
|
||
|
||
def group(node, capture=None):
    """Return ``node``, indexed by ``capture`` when a truthy capture name is given.

    :param node: pattern node to group.
    :param capture: optional capture name; when falsy the node is returned
        unchanged.
    """
    return node[capture] if capture else node
|
||
|
||
def function(fn):
    """Wrap ``fn`` into a constraint that also receives captured attribute values.

    The returned wrapper calls ``fn(value, **extras)``. ``extras`` holds one
    entry per truthy capture, keyed ``"<capture name>_<attr_name>"``, whose
    value is that capture's attribute named ``attr_name``.

    :param fn: callable taking the value plus the keyword arguments above.
    :return: ``wrapper(attr_name, value, captures)``.
    """
    def wrapper(attr_name, value, captures):
        extras = {}
        for capture_name, capture_value in captures.items():
            # Falsy captures are skipped.
            if not capture_value:
                continue
            key = f"{capture_name}_{attr_name}"
            extras[key] = getattr(capture_value, attr_name)
        return fn(value, **extras)

    return wrapper
|
||
|
||
@dataclass
class Match:
    """One result of ``search``: the matched events and the named captures."""

    # The trail of the match, as produced by the regexp engine.
    events: List[Event]
    # Capture name -> captured value.
    # NOTE(review): search() stores only the first item of each capture's
    # trail here, not a list -- confirm the annotation against the callers.
    captures: Dict[str, List[Event]]
|
||
|
||
def search(dataset: EventDataset, pattern: Node[Tok, Out]):
    """Search ``dataset.events`` for ``pattern``, trying every start position.

    The compiled pattern is tried against the tail of the event list starting
    at each index in turn. Only the first match found at a position is kept.

    :param dataset: event dataset whose ``events`` are searched.
    :param pattern: pattern AST, compiled with ``RegExp.from_ast``.
    :return: list of ``Match`` objects, in order of start position.
    """
    events = dataset.events
    compiled = RegExp.from_ast(pattern)

    found = []
    for offset in range(len(events)):
        candidates = compiled.match(events[offset:], consume_all=False)
        if not candidates:
            continue

        first = candidates[0]
        matched_events = first.trail
        # TODO: check trail[0] because this points to the first event in the capture and not
        #       all of them
        first_captured = {}
        for capture_name, capture_value in first.children.items():
            first_captured[capture_name] = capture_value[0].trail[0]

        found.append(Match(events=matched_events, captures=first_captured))

    return found
|
||
|
||
@dataclass
class Query:
    """A pattern together with the event types it needs loaded."""

    # Event type names, e.g. ["pass", "shot"]; kloppy-query passes them to the
    # dataset loader as the "event_types" option.
    event_types: List[str]
    # Pattern AST handed to ``search``.
    pattern: Node[Tok, Out]
|
||
|
||
# Public API of this module. `group` is used by query files (for example
# ball_recovery.py), `Match` is what `search` returns, and `match_generic` is
# the builder behind the match_* helpers.
__all__ = [
    "search",
    "match_generic",
    "match_pass", "match_carry", "match_take_on", "match_shot",
    "same_as", "not_same_as", "function", "group",
    "Match", "Query",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .ast import * | ||
from .matchers import * | ||
from .regexp import * |
Oops, something went wrong.