diff --git a/refinery/units/meta/dedup.py b/refinery/units/meta/dedup.py index 88b21e192..29eaebf62 100644 --- a/refinery/units/meta/dedup.py +++ b/refinery/units/meta/dedup.py @@ -1,35 +1,62 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from refinery.units import Unit, Arg +from refinery.lib.tools import isbuffer +from refinery.lib.meta import metavars +from refinery.lib.argformats import PythonExpression + +from hashlib import md5 class dedup(Unit): """ Deduplicates a sequence of multiple inputs. The deduplication is limited to the current `refinery.lib.frame`. """ - def __init__(self, count: Arg.Switch('-c', help='Store the count of each deduplicated chunk.') = False): - super().__init__(count=count) + def __init__( + self, + key: Arg('key', type=str, help='An optional meta variable expression to deduplicate.') = None, + count: Arg.Switch('-c', help='Store the count of each deduplicated chunk.') = False + ): + super().__init__(key=key, count=count) def filter(self, chunks): + keyvar = self.args.key + + if keyvar is not None: + def key(chunk): + v = PythonExpression.Evaluate(keyvar, metavars(chunk)) + if isbuffer(v): + v = md5(v).digest() + return v + else: + def key(chunk): + return md5(chunk).digest() + if self.args.count: - from collections import Counter - barrier = Counter(chunks) - for chunk in chunks: - if not chunk.visible: - yield chunk - continue - barrier.update(chunk) - for chunk, count in barrier.items(): - chunk.meta['count'] = count - yield chunk + counts = {} + buffer = {} + hashes = None else: - from hashlib import md5 - barrier = set() - for chunk in chunks: - if not chunk.visible: - yield chunk - continue - hashed = md5(chunk).digest() - if hashed not in barrier: - barrier.add(hashed) - yield chunk + hashes = set() + counts = None + buffer = None + + for chunk in chunks: + if not chunk.visible: + yield chunk + continue + + uid = key(chunk) + + if hashes is None: + counts[uid] = counts.get(uid, 0) + 1 + buffer.setdefault(uid, chunk) + elif uid in hashes: + continue + else: + hashes.add(uid) + yield chunk + + if hashes is None: + for uid, chunk in buffer.items(): + yield self.labelled(chunk, count=counts[uid]) diff --git a/test/units/meta/test_dedup.py b/test/units/meta/test_dedup.py index bd8c3932b..91b7190ec 100644 --- a/test/units/meta/test_dedup.py +++ b/test/units/meta/test_dedup.py @@ -39,3 +39,7 @@ def test_duplicated_strings(self): def test_count(self): pipeline = L('emit HELLO-WORLD [| push [| rex . | dedup -c | sorted count | pick :2 | pop t s ]| cfmt {t}{s} ]') self.assertEqual(pipeline(), B'LO') + + def test_key(self): + pipeline = L('emit FOO BAR BAZ BAMPF [| dedup size ]') + self.assertEqual(pipeline(), B'FOOBAMPF')