Skip to content

Commit

Permalink
make cache() and persist() useful. Version 0.2.13.
Browse files Browse the repository at this point in the history
  • Loading branch information
svenkreiss committed May 28, 2015
1 parent 696403e commit 9a9922a
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 14 deletions.
10 changes: 7 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ RDD
* ``aggregate(zeroValue, seqOp, combOp)``: aggregate value in partition with
seqOp and combine with combOp
* ``aggregateByKey(zeroValue, seqFunc, combFunc)``: aggregate by key
* ``cache()``: execute previous steps and cache result
* ``cache()``: synonym for ``persist()``
* ``cartesian(other)``: cartesian product
* ``coalesce()``: do nothing
* ``collect()``: return the underlying list
Expand Down Expand Up @@ -149,7 +149,7 @@ RDD
* ``mean()``: mean
* ``min()``: get the minimum element
* ``name()``: RDD's name
* ``persist()``: implemented as synonym for ``cache()``
* ``persist()``: caches outputs of previous operations (previous steps are still executed lazily)
* ``pipe(command)``: pipe the elements through an external command line tool
* ``reduce()``: reduce
* ``reduceByKey()``: reduce by key and return the new RDD
Expand Down Expand Up @@ -213,7 +213,11 @@ Infers ``.gz`` and ``.bz2`` compressions from the file name.
Changelog
=========

* `master <https://github.com/svenkreiss/pysparkling/compare/v0.2.10...master>`_
* `master <https://github.com/svenkreiss/pysparkling/compare/v0.2.13...master>`_
* `v0.2.13 <https://github.com/svenkreiss/pysparkling/compare/v0.2.10...v0.2.13>`_ (2015-05-28)
* make ``cache()`` and ``persist()`` do something useful
* logo
* fix ``foreach()``
* `v0.2.10 <https://github.com/svenkreiss/pysparkling/compare/v0.2.8...v0.2.10>`_ (2015-05-27)
* fix ``fileio.codec`` import
* support ``http://``
Expand Down
2 changes: 1 addition & 1 deletion pysparkling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""pysparkling module."""

__version__ = '0.2.12'
__version__ = '0.2.13'

from .exceptions import (FileAlreadyExistsException,
ConnectionException)
Expand Down
16 changes: 16 additions & 0 deletions pysparkling/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,19 @@ def __getstate__(self):
'index': self.index,
'_x': list(self.x())
}


class PersistedPartition(Partition):
    """Partition that can cache its computed elements.

    ``storageLevel`` is stored for API compatibility with Spark but is
    not otherwise interpreted here.
    """

    def __init__(self, x, idx=None, storageLevel=None):
        Partition.__init__(self, x, idx)
        # Cached iterator over computed elements; None until
        # set_cache_x() is called.
        self.cache_x = None
        self.storageLevel = storageLevel

    def x(self):
        """Return an iterator over this partition's elements.

        When a cache is present, tee the cached iterator so that the
        cache survives consumption and can be replayed on later calls;
        otherwise fall back to the uncached data.
        """
        # Explicit identity check: cache_x is an iterator object, and a
        # truthiness test on any iterator is always True — `is not None`
        # states the actual intent (has a cache been set?).
        if self.cache_x is not None:
            self.cache_x, r = itertools.tee(self.cache_x)
            return r
        return Partition.x(self)

    def set_cache_x(self, x):
        """Materialize ``x`` and store it as the cached element source."""
        self.cache_x = iter(list(x))
39 changes: 32 additions & 7 deletions pysparkling/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from . import utils
from .fileio import File
from .partition import PersistedPartition
from .exceptions import FileAlreadyExistsException

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,11 +73,7 @@ def combFuncByKey(l):
resultHandler=combFuncByKey)

def cache(self):
# This cache is not lazy, but it will guarantee that previous
# steps are only executed once.
for p in self.partitions():
p._x = list(p.x())
return self
return self.persist()

def cartesian(self, other):
v1 = self.collect()
Expand Down Expand Up @@ -275,7 +272,7 @@ def mapPartitions(self, f, preservesPartitioning=False):
return MapPartitionsRDD(
self,
lambda tc, i, x: f(x),
preservesPartitioning=True,
preservesPartitioning=preservesPartitioning,
)

def mapValues(self, f):
Expand Down Expand Up @@ -322,7 +319,8 @@ def name(self):
return self._name

def persist(self, storageLevel=None):
return self.cache()
"""[distributed]"""
return PersistedRDD(self, storageLevel=storageLevel)

def pipe(self, command, env={}):
return self.context.parallelize(subprocess.check_output(
Expand Down Expand Up @@ -504,3 +502,30 @@ def compute(self, split, task_context):

def partitions(self):
return self.prev.partitions()


class PersistedRDD(RDD):
    def __init__(self, prev, storageLevel=None):
        """prev is the previous RDD.

        Wraps each partition of ``prev`` in a PersistedPartition so that
        computed results can be cached on first evaluation.
        """
        # NOTE(review): the partitions argument is a generator, so
        # p.x() is only evaluated when RDD.__init__ (or a later
        # consumer) iterates it — presumably keeping construction lazy;
        # confirm against RDD.__init__'s handling of the iterable.
        RDD.__init__(
            self,
            (
                PersistedPartition(
                    p.x(),
                    p.index,
                    storageLevel,
                ) for p in prev.partitions()
            ),
            prev.context,
        )

        # Keep a reference to the upstream RDD so compute() can fall
        # back to it on a cache miss.
        self.prev = prev

    def compute(self, split, task_context):
        # First access: run the previous RDD's computation and store the
        # result on the partition; subsequent accesses replay the cache
        # via PersistedPartition.x().
        if split.cache_x is None:
            split.set_cache_x(
                self.prev.compute(split, task_context._create_child())
            )
        return split.x()
26 changes: 23 additions & 3 deletions tests/test_rdd_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,32 @@ def test_aggregate():
def test_aggregateByKey():
seqOp = (lambda x, y: x + y)
combOp = (lambda x, y: x + y)
r = Context().parallelize([('a', 1), ('b', 2), ('a', 3), ('c', 4)]).aggregateByKey(int, seqOp, combOp)
r = Context().parallelize(
[('a', 1), ('b', 2), ('a', 3), ('c', 4)]
).aggregateByKey(int, seqOp, combOp)
assert r['a'] == 4 and r['b'] == 2


def test_cache():
    """cache() should run map() once and replay results on later actions."""
    my_rdd = Context().parallelize([1, 2, 3, 4], 2)
    my_rdd = my_rdd.map(lambda x: x*x).cache()
    print('no exec until here')
    print(my_rdd.first())
    print('executed map on first partition only')
    print(my_rdd.collect())
    print('now map() was executed on all partitions and should '
          'not be executed again')
    # Collect once instead of twice in the assert: each collect()
    # re-reads the partitions, and repeating it obscures what is
    # actually being checked.
    result = my_rdd.collect()
    assert len(result) == 4 and 16 in result


def test_cartesian():
    """Cartesian product of a two-element RDD with itself yields 4 pairs."""
    source = Context().parallelize([1, 2])
    pairs = source.cartesian(source).collect()
    pairs.sort()
    print(pairs)
    assert len(pairs) == 4
    assert len(pairs[0]) == 2
    assert pairs[0][0] == 1
    assert pairs[2][0] == 2


def test_coalesce():
    """coalesce(1) collapses a two-partition RDD into a single partition."""
    source = Context().parallelize([1, 2, 3], 2)
    coalesced = source.coalesce(1)
    assert coalesced.getNumPartitions() == 1
Expand Down Expand Up @@ -61,7 +78,10 @@ def test_distinct():


def test_filter():
my_rdd = Context().parallelize([1, 2, 2, 4, 1, 3, 5, 9], 3).filter(lambda x: x % 2 == 0)
my_rdd = Context().parallelize(
[1, 2, 2, 4, 1, 3, 5, 9],
3,
).filter(lambda x: x % 2 == 0)
print(my_rdd.collect())
print(my_rdd.count())
assert my_rdd.count() == 3
Expand Down Expand Up @@ -254,4 +274,4 @@ def test_takeSample_partitions():

if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
test_sample()
test_cache()

0 comments on commit 9a9922a

Please sign in to comment.