-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #36 from Bihaqo/develop
0.2.0
- Loading branch information
Showing
20 changed files
with
2,379 additions
and
423 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Change Log | ||
All notable changes to this project will be documented in this file. | ||
|
||
The format is based on [Keep a Changelog](http://keepachangelog.com/) | ||
and this project adheres to [Semantic Versioning](http://semver.org/). | ||
|
||
## [Unreleased] | ||
|
||
## [0.2.0] - 2017-03-23 | ||
### Added | ||
- (Partial) support for batches of TT-tensors. | ||
- Riemannian module (projection on the tangent space). | ||
- op property and str method for TensorTrain | ||
- concat_along_batch_dim | ||
- expand_batch_dim | ||
- gram_matrix | ||
- Multiplication by a number | ||
|
||
### Changed | ||
- Fix add function for dtypes not equal tf.float32 | ||
- flat_inner and quadratic_form now return numbers (instead of 1 x 1 tensors) | ||
|
||
## [0.1.0] - 2017-03-12 | ||
### Added | ||
- Indexing (e.g. TensorTrain[:, 3, 2:4]) | ||
- Full (converting TT to dense) | ||
- TT-SVD and rounding | ||
- Basic arithmetic (add, multiply, matmul, flat_inner) | ||
- Variables support | ||
- Kronecker module (functions for TT-rank 1 TT-matrices) | ||
- quadratic_form | ||
- frobenius_norm | ||
|
||
[Unreleased]: https://github.com/Bihaqo/t3f/compare/master...develop | ||
[0.2.0]: https://github.com/Bihaqo/t3f/compare/0.1.0...0.2.0 | ||
[0.1.0]: https://github.com/Bihaqo/t3f/compare/f24409508...0.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
from tensor_train import * | ||
from tensor_train_base import TensorTrainBase | ||
from tensor_train import TensorTrain | ||
from tensor_train_batch import TensorTrainBatch | ||
from variables import * | ||
from ops import * | ||
from batch_ops import * | ||
from initializers import * | ||
from regularizers import * | ||
from riemannian import * | ||
from shapes import * | ||
from decompositions import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import tensorflow as tf | ||
|
||
from tensor_train_base import TensorTrainBase | ||
from tensor_train_batch import TensorTrainBatch | ||
import ops | ||
|
||
|
||
def concat_along_batch_dim(tt_list): | ||
"""Concat all TensorTrainBatch objects along batch dimension. | ||
Args: | ||
tt_list: a list of TensorTrainBatch objects. | ||
Returns: | ||
TensorTrainBatch | ||
""" | ||
ndims = tt_list[0].ndims() | ||
|
||
if isinstance(tt_list, TensorTrainBase): | ||
# Not a list but just one element, nothing to concat. | ||
return tt_list | ||
|
||
for batch_idx in range(len(tt_list)): | ||
if not isinstance(tt_list[batch_idx], TensorTrainBatch): | ||
raise ValueError('All objects in the list should be TTBatch objects, got ' | ||
'%s' % tt_list[batch_idx]) | ||
for batch_idx in range(1, len(tt_list)): | ||
if tt_list[batch_idx].get_raw_shape() != tt_list[0].get_raw_shape(): | ||
raise ValueError('Shapes of all TT-batch objects should coincide, got %s ' | ||
'and %s' % (tt_list[0].get_raw_shape(), | ||
tt_list[batch_idx].get_raw_shape())) | ||
if tt_list[batch_idx].get_tt_ranks() != tt_list[0].get_tt_ranks(): | ||
raise ValueError('TT-ranks of all TT-batch objects should coincide, got ' | ||
'%s and %s' % (tt_list[0].get_tt_ranks(), | ||
tt_list[batch_idx].get_tt_ranks())) | ||
|
||
res_cores = [] | ||
for core_idx in range(ndims): | ||
curr_core = tf.concat([tt.tt_cores[core_idx] for tt in tt_list], axis=0) | ||
res_cores.append(curr_core) | ||
|
||
batch_size = sum([tt.batch_size for tt in tt_list]) | ||
|
||
return TensorTrainBatch(res_cores, tt_list[0].get_raw_shape(), | ||
tt_list[0].get_tt_ranks(), batch_size) | ||
|
||
|
||
def gram_matrix(tt_vectors, matrix=None): | ||
"""Computes Gramian matrix of a batch of TT-vecors. | ||
If matrix is None, computes | ||
res[i, j] = t3f.flat_inner(tt_vectors[i], tt_vectors[j]). | ||
If matrix is present, computes | ||
res[i, j] = t3f.flat_inner(tt_vectors[i], t3f.matmul(matrix, tt_vectors[j])) | ||
or more shorly | ||
res[i, j] = tt_vectors[i]^T * matrix * tt_vectors[j] | ||
Args: | ||
tt_vectors: TensorTrainBatch. | ||
matrix: None, or TensorTrain matrix. | ||
Returns: | ||
tf.tensor with the Gram matrix. | ||
""" | ||
ndims = tt_vectors.ndims() | ||
if matrix is None: | ||
curr_core = tt_vectors.tt_cores[0] | ||
res = tf.einsum('paijb,qcijd->pqbd', curr_core, curr_core) | ||
for core_idx in range(1, ndims): | ||
curr_core = tt_vectors.tt_cores[core_idx] | ||
res = tf.einsum('pqac,paijb,qcijd->pqbd', res, curr_core, curr_core) | ||
else: | ||
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j] | ||
vectors_shape = tt_vectors.get_shape() | ||
if vectors_shape[2] == 1 and vectors_shape[1] != 1: | ||
# TODO: not very efficient, better to use different order in einsum. | ||
tt_vectors = ops.transpose(tt_vectors) | ||
vectors_shape = tt_vectors.get_shape() | ||
if vectors_shape[1] != 1: | ||
# TODO: do something so that in case the shape is undefined on compilation | ||
# it still works. | ||
raise ValueError('The tt_vectors argument should be vectors (not ' | ||
'matrices) with shape defined on compilation.') | ||
curr_core = tt_vectors.tt_cores[0] | ||
curr_matrix_core = matrix.tt_cores[0] | ||
# We enumerate the dummy dimension (that takes 1 value) with `k`. | ||
res = tf.einsum('pakib,cijd,qekjf->pqbdf', curr_core, curr_matrix_core, | ||
curr_core) | ||
for core_idx in range(1, ndims): | ||
curr_core = tt_vectors.tt_cores[core_idx] | ||
curr_matrix_core = matrix.tt_cores[core_idx] | ||
res = tf.einsum('pqace,pakib,cijd,qekjf->pqbdf', res, curr_core, | ||
curr_matrix_core, curr_core) | ||
|
||
# Squeeze to make the result of size batch_size x batch_size instead of | ||
# batch_size x batch_size x 1 x 1. | ||
return tf.squeeze(res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from tensor_train import TensorTrain | ||
from tensor_train_batch import TensorTrainBatch | ||
import ops | ||
import batch_ops | ||
import initializers | ||
|
||
|
||
class BatchOpsTest(tf.test.TestCase): | ||
|
||
def testConcatMatrix(self): | ||
# Test concating TTMatrix batches along batch dimension. | ||
first = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=1) | ||
second = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=4) | ||
third = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3) | ||
first_res = batch_ops.concat_along_batch_dim((first)) | ||
first_res = ops.full(first_res) | ||
first_second_res = batch_ops.concat_along_batch_dim((first, second)) | ||
first_second_res = ops.full(first_second_res) | ||
first_second_third_res = batch_ops.concat_along_batch_dim((first, second, | ||
third)) | ||
first_second_third_res = ops.full(first_second_third_res) | ||
|
||
first_full = ops.full(first) | ||
second_full = ops.full(second) | ||
third_full = ops.full(third) | ||
first_desired = first_full | ||
first_second_desired = tf.concat((first_full, second_full), axis=0) | ||
first_second_third_desired = tf.concat((first_full, second_full, third_full), | ||
axis=0) | ||
with self.test_session() as sess: | ||
res = sess.run((first_res, first_second_res, first_second_third_res, | ||
first_desired, first_second_desired, | ||
first_second_third_desired)) | ||
first_res_val = res[0] | ||
first_second_res_val = res[1] | ||
first_second_third_res_val = res[2] | ||
first_desired_val = res[3] | ||
first_second_desired_val = res[4] | ||
first_second_third_desired_val = res[5] | ||
self.assertAllClose(first_res_val, first_desired_val) | ||
self.assertAllClose(first_second_res_val, first_second_desired_val) | ||
self.assertAllClose(first_second_third_res_val, first_second_third_desired_val) | ||
|
||
def testGramMatrix(self): | ||
# Test Gram Matrix of a batch of TT vectors. | ||
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=5) | ||
res_actual = batch_ops.gram_matrix(tt_vectors) | ||
full_vectors = tf.reshape(ops.full(tt_vectors), (5, 6)) | ||
res_desired = tf.matmul(full_vectors, tf.transpose(full_vectors)) | ||
res_desired = tf.squeeze(res_desired) | ||
with self.test_session() as sess: | ||
res_actual_val, res_desired_val = sess.run((res_actual, res_desired)) | ||
self.assertAllClose(res_desired_val, res_actual_val) | ||
|
||
def testGramMatrixWithMatrix(self): | ||
# Test Gram Matrix of a batch of TT vectors with providing a matrix, so we | ||
# should compute | ||
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j] | ||
tt_vectors = initializers.random_matrix_batch((None, (2, 3)), batch_size=4) | ||
matrix = initializers.random_matrix(((2, 3), (2, 3))) | ||
res_actual = batch_ops.gram_matrix(tt_vectors, matrix) | ||
full_vectors = tf.reshape(ops.full(tt_vectors), (4, 6)) | ||
with self.test_session() as sess: | ||
res = sess.run((res_actual, full_vectors, ops.full(matrix))) | ||
res_actual_val, vectors_val, matrix_val = res | ||
res_desired_val = np.zeros((4, 4)) | ||
for i in range(4): | ||
for j in range(4): | ||
curr_val = np.dot(vectors_val[i], matrix_val) | ||
curr_val = np.dot(curr_val, vectors_val[j]) | ||
res_desired_val[i, j] = curr_val | ||
self.assertAllClose(res_desired_val, res_actual_val, atol=1e-5, rtol=1e-5) | ||
|
||
if __name__ == "__main__": | ||
tf.test.main() | ||
|
Oops, something went wrong.