diff --git a/README.md b/README.md index 864523c..b2be117 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Tensor Puzzles -- by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp) (with Marco Treviso) +- by [Sasha Rush](http://rush-nlp.com) - [srush_nlp](https://twitter.com/srush_nlp) (with Marcos Treviso) @@ -40,7 +40,7 @@ tensor = torch.tensor ## Rules -1. These puzzles are about broadcasting. Know this rule. +1. These puzzles are about *broadcasting*. Know this rule. ![](https://pbs.twimg.com/media/FQywor0WYAssn7Y?format=png&name=large) @@ -57,8 +57,20 @@ def arange(i: int): return torch.tensor(range(i)) draw_examples("arange", [{"" : arange(i)} for i in [5, 3, 9]]) +``` + + + + + +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_7_0.svg) + + + +```python +# Example of broadcasting. examples = [(arange(4), arange(5)[:, None]) , (arange(3)[:, None], arange(2))] draw_examples("broadcast", [{"a": a, "b":b, "ret": a + b} for a, b in examples]) @@ -68,7 +80,7 @@ draw_examples("broadcast", [{"a": a, "b":b, "ret": a + b} for a, b in examples]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_7_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_8_0.svg) @@ -101,7 +113,7 @@ draw_examples("where", [{"q": q, "a":a, "b":b, "ret": where(q, a, b)} for q, a, -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_8_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_9_0.svg) @@ -120,7 +132,7 @@ test_ones = make_test("one", ones, ones_spec, add_sizes=["i"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_9_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_10_0.svg) @@ -149,7 +161,7 @@ test_sum = make_test("sum", sum, sum_spec) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_12_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_13_0.svg) @@ -177,7 +189,7 @@ test_outer = make_test("outer", outer, outer_spec) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_15_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_16_0.svg) @@ -205,7 +217,7 @@ test_diag = make_test("diag", diag, diag_spec) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_18_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_19_0.svg) @@ -232,7 +244,7 @@ test_eye = make_test("eye", eye, eye_spec, add_sizes=["j"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_21_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_22_0.svg) @@ -264,7 +276,7 @@ test_triu = make_test("triu", triu, triu_spec, add_sizes=["j"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_24_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_25_0.svg) @@ -293,7 +305,7 @@ test_cumsum = make_test("cumsum", cumsum, cumsum_spec) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_27_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_28_0.svg) @@ -321,7 +333,7 @@ test_diff = make_test("diff", diff, diff_spec, add_sizes=["i"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_30_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_31_0.svg) @@ -350,7 +362,7 @@ test_vstack = make_test("vstack", vstack, vstack_spec) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_33_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_34_0.svg) @@ -381,7 +393,7 @@ test_roll = make_test("roll", roll, roll_spec, add_sizes=["i"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_36_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_37_0.svg) @@ -409,7 +421,7 @@ test_flip = make_test("flip", flip, flip_spec, add_sizes=["i"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_39_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_40_0.svg) @@ -441,7 +453,7 @@ test_compress = make_test("compress", compress, compress_spec, add_sizes=["i"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_42_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_43_0.svg) @@ -471,7 +483,7 @@ test_pad_to = make_test("pad_to", pad_to, pad_to_spec, add_sizes=["i", "j"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_45_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_46_0.svg) @@ -511,7 +523,7 @@ test_sequence = make_test("sequence_mask", -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_48_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_49_0.svg) @@ -546,7 +558,7 @@ test_bincount = make_test("bincount", -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_51_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_52_0.svg) @@ -581,7 +593,7 @@ test_scatter_add = make_test("scatter_add", -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_54_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_55_0.svg) @@ -611,7 +623,7 @@ test_flatten = make_test("flatten", flatten, flatten_spec, add_sizes=["i", "j"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_57_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_58_0.svg) @@ -638,7 +650,7 @@ test_linspace = make_test("linspace", linspace, linspace_spec, add_sizes=["n"]) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_60_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_61_0.svg) @@ -668,7 +680,7 @@ test_heaviside = make_test("heaviside", heaviside, heaviside_spec) -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_63_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_64_0.svg) @@ -706,7 +718,7 @@ test_repeat = make_test("repeat", repeat, repeat_spec, constraint=constraint_set -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_66_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_67_0.svg) @@ -742,7 +754,7 @@ test_bucketize = make_test("bucketize", bucketize, bucketize_spec, -![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_67_0.svg) +![svg](Tensor%20Puzzlers_files/Tensor%20Puzzlers_68_0.svg)