From b8e1ab82d9203c7c461028bac13bb457ce2c9fd0 Mon Sep 17 00:00:00 2001 From: Liam Keegan Date: Fri, 17 Jan 2025 10:46:01 +0100 Subject: [PATCH 1/2] add `index_bits` option to support large datasets - add `index_bits` optional argument to KDTree - default value is 32: preserves existing behaviour & performance - user can specify 64 instead to use 64-bit integers - this ensures correct results when n_points * k > 2^32 - uses approx 50% more RAM than 32 bit option - resolves #38 - in 32-bit int mode add checks to avoid returning incorrect results - KDTree checks that `n_points < 2^32` - KDTree.query checks that `n_points * k < 2^32` - update tests - parametrize all existing tests to test 32-bit and 64-bit int mode - add a query test with n_points * k too large - didn't add a test with n_points too large as it would require 16GB RAM to run --- pykdtree/_kdtree_core.c | 1892 ++++++++++++++++++++++++++++------ pykdtree/_kdtree_core.c.mako | 340 +++--- pykdtree/kdtree.pyx | 188 +++- pykdtree/test_tree.py | 97 +- 4 files changed, 1949 insertions(+), 568 deletions(-) diff --git a/pykdtree/_kdtree_core.c b/pykdtree/_kdtree_core.c index 2d6862b..bfcc90d 100644 --- a/pykdtree/_kdtree_core.c +++ b/pykdtree/_kdtree_core.c @@ -29,7 +29,13 @@ Anne M. Archibald and libANN by David M. Mount and Sunil Arya. #include #define PA(i,d) (pa[no_dims * pidx[i] + d]) -#define PASWAP(a,b) { uint32_t tmp = pidx[a]; pidx[a] = pidx[b]; pidx[b] = tmp; } +#define PASWAP_int32_t(a,b) { uint32_t tmp = pidx[a]; pidx[a] = pidx[b]; pidx[b] = tmp; } +#define PASWAP_int64_t(a,b) { uint64_t tmp = pidx[a]; pidx[a] = pidx[b]; pidx[b] = tmp; } + +#define IDX_MAX_int32_t UINT32_MAX +#define IDX_MAX_int64_t UINT64_MAX +#define DIST_MAX_float FLT_MAX +#define DIST_MAX_double DBL_MAX #ifdef _MSC_VER #define restrict __restrict @@ -44,17 +50,38 @@ typedef struct uint32_t n; float cut_bounds_lv; float cut_bounds_hv; - struct Node_float *left_child; - struct Node_float *right_child; -} Node_float; + struct Node_float_int32_t *left_child; + struct Node_float_int32_t *right_child; +} Node_float_int32_t; typedef struct { float *bbox; int8_t no_dims; uint32_t *pidx; - struct Node_float *root; -} Tree_float; + struct Node_float_int32_t *root; +} Tree_float_int32_t; + + +typedef struct +{ + float cut_val; + int8_t cut_dim; + uint64_t start_idx; + uint64_t n; + float cut_bounds_lv; + float cut_bounds_hv; + struct Node_float_int64_t *left_child; + struct Node_float_int64_t *right_child; +} Node_float_int64_t; + +typedef struct +{ + float *bbox; + int8_t no_dims; + uint64_t *pidx; + struct Node_float_int64_t *root; +} Tree_float_int64_t; typedef struct @@ -65,67 +92,1424 @@ typedef struct uint32_t n; double cut_bounds_lv; double cut_bounds_hv; - struct Node_double *left_child; - struct Node_double *right_child; -} Node_double; + struct Node_double_int32_t *left_child; + struct Node_double_int32_t *right_child; +} Node_double_int32_t; typedef struct { double *bbox; int8_t no_dims; uint32_t *pidx; - struct Node_double *root; -} Tree_double; + struct Node_double_int32_t *root; +} Tree_double_int32_t; + + +typedef struct +{ + double cut_val; + int8_t cut_dim; + uint64_t start_idx; + uint64_t n; + double cut_bounds_lv; + double cut_bounds_hv; + struct Node_double_int64_t *left_child; + struct Node_double_int64_t *right_child; +} Node_double_int64_t; + +typedef struct +{ + double *bbox; + int8_t no_dims; + uint64_t *pidx; + struct Node_double_int64_t *root; +} Tree_double_int64_t; + + + +float calc_dist_float(float *point1_coord, float *point2_coord, int8_t no_dims); +float get_cube_offset_float(int8_t dim, float *point_coord, float *bbox); +float get_min_dist_float(float *point_coord, int8_t no_dims, float *bbox); + + +void insert_point_float_int32_t(uint32_t *closest_idx, float *closest_dist, uint32_t pidx, float cur_dist, uint32_t k); +void get_bounding_box_float_int32_t(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, float *bbox); +int partition_float_int32_t(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *bbox, int8_t *cut_dim, + float *cut_val, uint32_t *n_lo); +Tree_float_int32_t* construct_tree_float_int32_t(float *pa, int8_t no_dims, uint32_t n, uint32_t bsp); +Node_float_int32_t* construct_subtree_float_int32_t(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, float *bbox); +Node_float_int32_t * create_node_float_int32_t(uint32_t start_idx, uint32_t n, int is_leaf); +void delete_subtree_float_int32_t(Node_float_int32_t *root); +void delete_tree_float_int32_t(Tree_float_int32_t *tree); +void print_tree_float_int32_t(Node_float_int32_t *root, int level); +void search_leaf_float_int32_t(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, + uint32_t k, uint32_t *restrict closest_idx, float *restrict closest_dist); +void search_leaf_float_int32_t_mask(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, + uint32_t k, uint8_t *restrict mask, uint32_t *restrict closest_idx, float *restrict closest_dist); +void search_splitnode_float_int32_t(Node_float_int32_t *root, float *pa, uint32_t *pidx, int8_t no_dims, float *point_coord, + float min_dist, uint32_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, uint32_t * closest_idx, float *closest_dist); +void search_tree_float_int32_t(Tree_float_int32_t *tree, float *pa, float *point_coords, + uint32_t num_points, uint32_t k, float distance_upper_bound, + float eps, uint8_t *mask, uint32_t *closest_idxs, float *closest_dists); + + +void insert_point_float_int64_t(uint64_t *closest_idx, float *closest_dist, uint64_t pidx, float cur_dist, uint64_t k); +void get_bounding_box_float_int64_t(float *pa, uint64_t *pidx, int8_t no_dims, uint64_t n, float *bbox); +int partition_float_int64_t(float *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, float *bbox, int8_t *cut_dim, + float *cut_val, uint64_t *n_lo); +Tree_float_int64_t* construct_tree_float_int64_t(float *pa, int8_t no_dims, uint64_t n, uint64_t bsp); +Node_float_int64_t* construct_subtree_float_int64_t(float *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, uint64_t bsp, float *bbox); +Node_float_int64_t * create_node_float_int64_t(uint64_t start_idx, uint64_t n, int is_leaf); +void delete_subtree_float_int64_t(Node_float_int64_t *root); +void delete_tree_float_int64_t(Tree_float_int64_t *tree); +void print_tree_float_int64_t(Node_float_int64_t *root, int level); +void search_leaf_float_int64_t(float *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, float *restrict point_coord, + uint64_t k, uint64_t *restrict closest_idx, float *restrict closest_dist); +void search_leaf_float_int64_t_mask(float *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, float *restrict point_coord, + uint64_t k, uint8_t *restrict mask, uint64_t *restrict closest_idx, float *restrict closest_dist); +void search_splitnode_float_int64_t(Node_float_int64_t *root, float *pa, uint64_t *pidx, int8_t no_dims, float *point_coord, + float min_dist, uint64_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, uint64_t * closest_idx, float *closest_dist); +void search_tree_float_int64_t(Tree_float_int64_t *tree, float *pa, float *point_coords, + uint64_t num_points, uint64_t k, float distance_upper_bound, + float eps, uint8_t *mask, uint64_t *closest_idxs, float *closest_dists); + + +double calc_dist_double(double *point1_coord, double *point2_coord, int8_t no_dims); +double get_cube_offset_double(int8_t dim, double *point_coord, double *bbox); +double get_min_dist_double(double *point_coord, int8_t no_dims, double *bbox); + + +void insert_point_double_int32_t(uint32_t *closest_idx, double *closest_dist, uint32_t pidx, double cur_dist, uint32_t k); +void get_bounding_box_double_int32_t(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, double *bbox); +int partition_double_int32_t(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *bbox, int8_t *cut_dim, + double *cut_val, uint32_t *n_lo); +Tree_double_int32_t* construct_tree_double_int32_t(double *pa, int8_t no_dims, uint32_t n, uint32_t bsp); +Node_double_int32_t* construct_subtree_double_int32_t(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, double *bbox); +Node_double_int32_t * create_node_double_int32_t(uint32_t start_idx, uint32_t n, int is_leaf); +void delete_subtree_double_int32_t(Node_double_int32_t *root); +void delete_tree_double_int32_t(Tree_double_int32_t *tree); +void print_tree_double_int32_t(Node_double_int32_t *root, int level); +void search_leaf_double_int32_t(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, + uint32_t k, uint32_t *restrict closest_idx, double *restrict closest_dist); +void search_leaf_double_int32_t_mask(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, + uint32_t k, uint8_t *restrict mask, uint32_t *restrict closest_idx, double *restrict closest_dist); +void search_splitnode_double_int32_t(Node_double_int32_t *root, double *pa, uint32_t *pidx, int8_t no_dims, double *point_coord, + double min_dist, uint32_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, uint32_t * closest_idx, double *closest_dist); +void search_tree_double_int32_t(Tree_double_int32_t *tree, double *pa, double *point_coords, + uint32_t num_points, uint32_t k, double distance_upper_bound, + double eps, uint8_t *mask, uint32_t *closest_idxs, double *closest_dists); + + +void insert_point_double_int64_t(uint64_t *closest_idx, double *closest_dist, uint64_t pidx, double cur_dist, uint64_t k); +void get_bounding_box_double_int64_t(double *pa, uint64_t *pidx, int8_t no_dims, uint64_t n, double *bbox); +int partition_double_int64_t(double *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, double *bbox, int8_t *cut_dim, + double *cut_val, uint64_t *n_lo); +Tree_double_int64_t* construct_tree_double_int64_t(double *pa, int8_t no_dims, uint64_t n, uint64_t bsp); +Node_double_int64_t* construct_subtree_double_int64_t(double *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, uint64_t bsp, double *bbox); +Node_double_int64_t * create_node_double_int64_t(uint64_t start_idx, uint64_t n, int is_leaf); +void delete_subtree_double_int64_t(Node_double_int64_t *root); +void delete_tree_double_int64_t(Tree_double_int64_t *tree); +void print_tree_double_int64_t(Node_double_int64_t *root, int level); +void search_leaf_double_int64_t(double *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, double *restrict point_coord, + uint64_t k, uint64_t *restrict closest_idx, double *restrict closest_dist); +void search_leaf_double_int64_t_mask(double *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, double *restrict point_coord, + uint64_t k, uint8_t *restrict mask, uint64_t *restrict closest_idx, double *restrict closest_dist); +void search_splitnode_double_int64_t(Node_double_int64_t *root, double *pa, uint64_t *pidx, int8_t no_dims, double *point_coord, + double min_dist, uint64_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, uint64_t * closest_idx, double *closest_dist); +void search_tree_double_int64_t(Tree_double_int64_t *tree, double *pa, double *point_coords, + uint64_t num_points, uint64_t k, double distance_upper_bound, + double eps, uint8_t *mask, uint64_t *closest_idxs, double *closest_dists); + + + +/************************************************ +Calculate squared cartesian distance between points +Params: + point1_coord : point 1 + point2_coord : point 2 +************************************************/ +float calc_dist_float(float *point1_coord, float *point2_coord, int8_t no_dims) +{ + /* Calculate squared distance */ + float dist = 0, dim_dist; + int8_t i; + for (i = 0; i < no_dims; i++) + { + dim_dist = point2_coord[i] - point1_coord[i]; + dist += dim_dist * dim_dist; + } + return dist; +} + +/************************************************ +Get squared distance from point to cube in specified dimension +Params: + dim : dimension + point_coord : cartesian coordinates of point + bbox : cube +************************************************/ +float get_cube_offset_float(int8_t dim, float *point_coord, float *bbox) +{ + float dim_coord = point_coord[dim]; + + if (dim_coord < bbox[2 * dim]) + { + /* Left of cube in dimension */ + return dim_coord - bbox[2 * dim]; + } + else if (dim_coord > bbox[2 * dim + 1]) + { + /* Right of cube in dimension */ + return dim_coord - bbox[2 * dim + 1]; + } + else + { + /* Inside cube in dimension */ + return 0.; + } +} + +/************************************************ +Get minimum squared distance between point and cube. +Params: + point_coord : cartesian coordinates of point + no_dims : number of dimensions + bbox : cube +************************************************/ +float get_min_dist_float(float *point_coord, int8_t no_dims, float *bbox) +{ + float cube_offset = 0, cube_offset_dim; + int8_t i; + + for (i = 0; i < no_dims; i++) + { + cube_offset_dim = get_cube_offset_float(i, point_coord, bbox); + cube_offset += cube_offset_dim * cube_offset_dim; + } + + return cube_offset; +} +/************************************************ +Insert point into priority queue +Params: + closest_idx : index queue + closest_dist : distance queue + pidx : permutation index of data points + cur_dist : distance to point inserted + k : number of neighbours +************************************************/ +void insert_point_float_int32_t(uint32_t *closest_idx, float *closest_dist, uint32_t pidx, float cur_dist, uint32_t k) +{ + int i; + for (i = k - 1; i > 0; i--) + { + if (closest_dist[i - 1] > cur_dist) + { + closest_dist[i] = closest_dist[i - 1]; + closest_idx[i] = closest_idx[i - 1]; + } + else + { + break; + } + } + closest_idx[i] = pidx; + closest_dist[i] = cur_dist; +} + +/************************************************ +Get the bounding box of a set of points +Params: + pa : data points + pidx : permutation index of data points + no_dims: number of dimensions + n : number of points + bbox : bounding box (return) +************************************************/ +void get_bounding_box_float_int32_t(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, float *bbox) +{ + float cur; + int8_t i, j; + uint32_t bbox_idx, i2; + + /* Use first data point to initialize */ + for (i = 0; i < no_dims; i++) + { + bbox[2 * i] = bbox[2 * i + 1] = PA(0, i); + } + + /* Update using rest of data points */ + for (i2 = 1; i2 < n; i2++) + { + for (j = 0; j < no_dims; j++) + { + bbox_idx = 2 * j; + cur = PA(i2, j); + if (cur < bbox[bbox_idx]) + { + bbox[bbox_idx] = cur; + } + else if (cur > bbox[bbox_idx + 1]) + { + bbox[bbox_idx + 1] = cur; + } + } + } +} + +/************************************************ +Partition a range of data points by manipulation the permutation index. +The sliding midpoint rule is used for the partitioning. +Params: + pa : data points + pidx : permutation index of data points + no_dims: number of dimensions + start_idx : index of first data point to use + n : number of data points + bbox : bounding box of data points + cut_dim : dimension used for partition (return) + cut_val : value of cutting point (return) + n_lo : number of point below cutting plane (return) +************************************************/ +int partition_float_int32_t(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *bbox, int8_t *cut_dim, float *cut_val, uint32_t *n_lo) +{ + int8_t dim = 0, i; + uint32_t p, q, i2; + float size = 0, min_val, max_val, split, side_len, cur_val; + uint32_t end_idx = start_idx + n - 1; + + /* Find largest bounding box side */ + for (i = 0; i < no_dims; i++) + { + side_len = bbox[2 * i + 1] - bbox[2 * i]; + if (side_len > size) + { + dim = i; + size = side_len; + } + } + + min_val = bbox[2 * dim]; + max_val = bbox[2 * dim + 1]; + + /* Check for zero length or inconsistent */ + if (min_val >= max_val) + return 1; + + /* Use middle for splitting */ + split = (min_val + max_val) / 2; + + /* Partition all data points around middle */ + p = start_idx; + q = end_idx; + while (p <= q) + { + if (PA(p, dim) < split) + { + p++; + } + else if (PA(q, dim) >= split) + { + /* Guard for underflow */ + if (q > 0) + { + q--; + } + else + { + break; + } + } + else + { + PASWAP_int32_t(p, q); + p++; + q--; + } + } + + /* Check for empty splits */ + if (p == start_idx) + { + /* No points less than split. + Split at lowest point instead. + Minimum 1 point will be in lower box. + */ + + uint32_t j = start_idx; + split = PA(j, dim); + for (i2 = start_idx + 1; i2 <= end_idx; i2++) + { + /* Find lowest point */ + cur_val = PA(i2, dim); + if (cur_val < split) + { + j = i2; + split = cur_val; + } + } + PASWAP_int32_t(j, start_idx); + p = start_idx + 1; + } + else if (p == end_idx + 1) + { + /* No points greater than split. + Split at highest point instead. + Minimum 1 point will be in higher box. + */ + + uint32_t j = end_idx; + split = PA(j, dim); + for (i2 = start_idx; i2 < end_idx; i2++) + { + /* Find highest point */ + cur_val = PA(i2, dim); + if (cur_val > split) + { + j = i2; + split = cur_val; + } + } + PASWAP_int32_t(j, end_idx); + p = end_idx; + } + + /* Set return values */ + *cut_dim = dim; + *cut_val = split; + *n_lo = p - start_idx; + return 0; +} + +/************************************************ +Construct a sub tree over a range of data points. +Params: + pa : data points + pidx : permutation index of data points + no_dims: number of dimensions + start_idx : index of first data point to use + n : number of data points + bsp : number of points per leaf + bbox : bounding box of set of data points +************************************************/ +Node_float_int32_t* construct_subtree_float_int32_t(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, float *bbox) +{ + /* Create new node */ + int is_leaf = (n <= bsp); + Node_float_int32_t *root = create_node_float_int32_t(start_idx, n, is_leaf); + int rval; + int8_t cut_dim; + uint32_t n_lo; + float cut_val, lv, hv; + if (is_leaf) + { + /* Make leaf node */ + root->cut_dim = -1; + } + else + { + /* Make split node */ + /* Partition data set and set node info */ + rval = partition_float_int32_t(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); + if (rval == 1) + { + root->cut_dim = -1; + return root; + } + root->cut_val = cut_val; + root->cut_dim = cut_dim; + + /* Recurse on both subsets */ + lv = bbox[2 * cut_dim]; + hv = bbox[2 * cut_dim + 1]; + + /* Set bounds for cut dimension */ + root->cut_bounds_lv = lv; + root->cut_bounds_hv = hv; + + /* Update bounding box before call to lower subset and restore after */ + bbox[2 * cut_dim + 1] = cut_val; + root->left_child = (struct Node_float_int32_t *)construct_subtree_float_int32_t(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); + bbox[2 * cut_dim + 1] = hv; + + /* Update bounding box before call to higher subset and restore after */ + bbox[2 * cut_dim] = cut_val; + root->right_child = (struct Node_float_int32_t *)construct_subtree_float_int32_t(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); + bbox[2 * cut_dim] = lv; + } + return root; +} + +/************************************************ +Construct a tree over data points. +Params: + pa : data points + no_dims: number of dimensions + n : number of data points + bsp : number of points per leaf +************************************************/ +Tree_float_int32_t* construct_tree_float_int32_t(float *pa, int8_t no_dims, uint32_t n, uint32_t bsp) +{ + Tree_float_int32_t *tree = (Tree_float_int32_t *)malloc(sizeof(Tree_float_int32_t)); + uint32_t i; + uint32_t *pidx; + float *bbox; + + tree->no_dims = no_dims; + + /* Initialize permutation array */ + pidx = (uint32_t *)malloc(sizeof(uint32_t) * n); + for (i = 0; i < n; i++) + { + pidx[i] = i; + } + + bbox = (float *)malloc(2 * sizeof(float) * no_dims); + get_bounding_box_float_int32_t(pa, pidx, no_dims, n, bbox); + tree->bbox = bbox; + + /* Construct subtree on full dataset */ + tree->root = (struct Node_float_int32_t *)construct_subtree_float_int32_t(pa, pidx, no_dims, 0, n, bsp, bbox); + + tree->pidx = pidx; + return tree; +} + +/************************************************ +Create a tree node. +Params: + start_idx : index of first data point to use + n : number of data points +************************************************/ +Node_float_int32_t* create_node_float_int32_t(uint32_t start_idx, uint32_t n, int is_leaf) +{ + Node_float_int32_t *new_node; + if (is_leaf) + { + /* + Allocate only the part of the struct that will be used in a leaf node. + This relies on the C99 specification of struct layout conservation and padding and + that dereferencing is never attempted for the node pointers in a leaf. + */ + new_node = (Node_float_int32_t *)malloc(sizeof(Node_float_int32_t) - 2 * sizeof(Node_float_int32_t *)); + } + else + { + new_node = (Node_float_int32_t *)malloc(sizeof(Node_float_int32_t)); + } + new_node->n = n; + new_node->start_idx = start_idx; + return new_node; +} + +/************************************************ +Delete subtree +Params: + root : root node of subtree to delete +************************************************/ +void delete_subtree_float_int32_t(Node_float_int32_t *root) +{ + if (root->cut_dim != -1) + { + delete_subtree_float_int32_t((Node_float_int32_t *)root->left_child); + delete_subtree_float_int32_t((Node_float_int32_t *)root->right_child); + } + free(root); +} + +/************************************************ +Delete tree +Params: + tree : Tree struct of kd tree +************************************************/ +void delete_tree_float_int32_t(Tree_float_int32_t *tree) +{ + delete_subtree_float_int32_t((Node_float_int32_t *)tree->root); + free(tree->bbox); + free(tree->pidx); + free(tree); +} + +/************************************************ +Print +************************************************/ +void print_tree_float_int32_t(Node_float_int32_t *root, int level) +{ + int i; + for (i = 0; i < level; i++) + { + printf(" "); + } + printf("(cut_val: %f, cut_dim: %i)\n", root->cut_val, root->cut_dim); + if (root->cut_dim != -1) + print_tree_float_int32_t((Node_float_int32_t *)root->left_child, level + 1); + if (root->cut_dim != -1) + print_tree_float_int32_t((Node_float_int32_t *)root->right_child, level + 1); +} + +/************************************************ +Search a leaf node for closest point +Params: + pa : data points + pidx : permutation index of data points + no_dims : number of dimensions + start_idx : index of first data point to use + size : number of data points + point_coord : query point + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_leaf_float_int32_t(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, + uint32_t k, uint32_t *restrict closest_idx, float *restrict closest_dist) +{ + float cur_dist; + uint32_t i; + /* Loop through all points in leaf */ + for (i = 0; i < n; i++) + { + /* Get distance to query point */ + cur_dist = calc_dist_float(&PA(start_idx + i, 0), point_coord, no_dims); + /* Update closest info if new point is closest so far*/ + if (cur_dist < closest_dist[k - 1]) + { + insert_point_float_int32_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + } + } +} + + +/************************************************ +Search a leaf node for closest point with data point mask +Params: + pa : data points + pidx : permutation index of data points + no_dims : number of dimensions + start_idx : index of first data point to use + size : number of data points + point_coord : query point + mask : boolean array of invalid (True) and valid (False) data points + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_leaf_float_int32_t_mask(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, + uint32_t k, uint8_t *mask, uint32_t *restrict closest_idx, float *restrict closest_dist) +{ + float cur_dist; + uint32_t i; + /* Loop through all points in leaf */ + for (i = 0; i < n; i++) + { + /* Is this point masked out? */ + if (mask[pidx[start_idx + i]]) + { + continue; + } + /* Get distance to query point */ + cur_dist = calc_dist_float(&PA(start_idx + i, 0), point_coord, no_dims); + /* Update closest info if new point is closest so far*/ + if (cur_dist < closest_dist[k - 1]) + { + insert_point_float_int32_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + } + } +} + +/************************************************ +Search subtree for nearest to query point +Params: + root : root node of subtree + pa : data points + pidx : permutation index of data points + no_dims : number of dimensions + point_coord : query point + min_dist : minumum distance to nearest neighbour + mask : boolean array of invalid (True) and valid (False) data points + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_splitnode_float_int32_t(Node_float_int32_t *root, float *pa, uint32_t *pidx, int8_t no_dims, float *point_coord, + float min_dist, uint32_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, + uint32_t *closest_idx, float *closest_dist) +{ + int8_t dim; + float dist_left, dist_right; + float new_offset; + float box_diff; + + /* Skip if distance bound exeeded */ + if (min_dist > distance_upper_bound) + { + return; + } + + dim = root->cut_dim; + + /* Handle leaf node */ + if (dim == -1) + { + if (mask) + { + search_leaf_float_int32_t_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); + } + else + { + search_leaf_float_int32_t(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); + } + return; + } + + /* Get distance to cutting plane */ + new_offset = point_coord[dim] - root->cut_val; + + if (new_offset < 0) + { + /* Left of cutting plane */ + dist_left = min_dist; + if (dist_left < closest_dist[k - 1] * eps_fac) + { + /* Search left subtree if minimum distance is below limit */ + search_splitnode_float_int32_t((Node_float_int32_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + + /* Right of cutting plane. Update minimum distance. + See Algorithms for Fast Vector Quantization + Sunil Arya and David M. Mount. */ + box_diff = root->cut_bounds_lv - point_coord[dim]; + if (box_diff < 0) + { + box_diff = 0; + } + dist_right = min_dist - box_diff * box_diff + new_offset * new_offset; + if (dist_right < closest_dist[k - 1] * eps_fac) + { + /* Search right subtree if minimum distance is below limit*/ + search_splitnode_float_int32_t((Node_float_int32_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + } + else + { + /* Right of cutting plane */ + dist_right = min_dist; + if (dist_right < closest_dist[k - 1] * eps_fac) + { + /* Search right subtree if minimum distance is below limit*/ + search_splitnode_float_int32_t((Node_float_int32_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + + /* Left of cutting plane. Update minimum distance. + See Algorithms for Fast Vector Quantization + Sunil Arya and David M. Mount. */ + box_diff = point_coord[dim] - root->cut_bounds_hv; + if (box_diff < 0) + { + box_diff = 0; + } + dist_left = min_dist - box_diff * box_diff + new_offset * new_offset; + if (dist_left < closest_dist[k - 1] * eps_fac) + { + /* Search left subtree if minimum distance is below limit*/ + search_splitnode_float_int32_t((Node_float_int32_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + } +} + +/************************************************ +Search for nearest neighbour for a set of query points +Params: + tree : Tree struct of kd tree + pa : data points + pidx : permutation index of data points + point_coords : query points + num_points : number of query points + mask : boolean array of invalid (True) and valid (False) data points + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_tree_float_int32_t(Tree_float_int32_t *tree, float *pa, float *point_coords, + uint32_t num_points, uint32_t k, float distance_upper_bound, + float eps, uint8_t *mask, uint32_t *closest_idxs, float *closest_dists) +{ + float min_dist; + float eps_fac = 1 / ((1 + eps) * (1 + eps)); + int8_t no_dims = tree->no_dims; + float *bbox = tree->bbox; + uint32_t *pidx = tree->pidx; + uint32_t j = 0; +#if defined(_MSC_VER) && defined(_OPENMP) + int32_t i = 0; + int32_t local_num_points = (int32_t) num_points; +#else + uint32_t i; + uint32_t local_num_points = num_points; +#endif + Node_float_int32_t *root = (Node_float_int32_t *)tree->root; + + /* Queries are OpenMP enabled */ + #pragma omp parallel + { + /* The low chunk size is important to avoid L2 cache trashing + for spatial coherent query datasets + */ + #pragma omp for private(i, j) schedule(static, 100) nowait + for (i = 0; i < local_num_points; i++) + { + for (j = 0; j < k; j++) + { + closest_idxs[i * k + j] = IDX_MAX_int32_t; + closest_dists[i * k + j] = DIST_MAX_float; + } + min_dist = get_min_dist_float(point_coords + no_dims * i, no_dims, bbox); + search_splitnode_float_int32_t(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, + k, distance_upper_bound, eps_fac, mask, &closest_idxs[i * k], &closest_dists[i * k]); + } + } +} + +/************************************************ +Insert point into priority queue +Params: + closest_idx : index queue + closest_dist : distance queue + pidx : permutation index of data points + cur_dist : distance to point inserted + k : number of neighbours +************************************************/ +void insert_point_float_int64_t(uint64_t *closest_idx, float *closest_dist, uint64_t pidx, float cur_dist, uint64_t k) +{ + int i; + for (i = k - 1; i > 0; i--) + { + if (closest_dist[i - 1] > cur_dist) + { + closest_dist[i] = closest_dist[i - 1]; + closest_idx[i] = closest_idx[i - 1]; + } + else + { + break; + } + } + closest_idx[i] = pidx; + closest_dist[i] = cur_dist; +} + +/************************************************ +Get the bounding box of a set of points +Params: + pa : data points + pidx : permutation index of data points + no_dims: number of dimensions + n : number of points + bbox : bounding box (return) +************************************************/ +void get_bounding_box_float_int64_t(float *pa, uint64_t *pidx, int8_t no_dims, uint64_t n, float *bbox) +{ + float cur; + int8_t i, j; + uint64_t bbox_idx, i2; + + /* Use first data point to initialize */ + for (i = 0; i < no_dims; i++) + { + bbox[2 * i] = bbox[2 * i + 1] = PA(0, i); + } + + /* Update using rest of data points */ + for (i2 = 1; i2 < n; i2++) + { + for (j = 0; j < no_dims; j++) + { + bbox_idx = 2 * j; + cur = PA(i2, j); + if (cur < bbox[bbox_idx]) + { + bbox[bbox_idx] = cur; + } + else if (cur > bbox[bbox_idx + 1]) + { + bbox[bbox_idx + 1] = cur; + } + } + } +} + +/************************************************ +Partition a range of data points by manipulation the permutation index. +The sliding midpoint rule is used for the partitioning. +Params: + pa : data points + pidx : permutation index of data points + no_dims: number of dimensions + start_idx : index of first data point to use + n : number of data points + bbox : bounding box of data points + cut_dim : dimension used for partition (return) + cut_val : value of cutting point (return) + n_lo : number of point below cutting plane (return) +************************************************/ +int partition_float_int64_t(float *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, float *bbox, int8_t *cut_dim, float *cut_val, uint64_t *n_lo) +{ + int8_t dim = 0, i; + uint64_t p, q, i2; + float size = 0, min_val, max_val, split, side_len, cur_val; + uint64_t end_idx = start_idx + n - 1; + + /* Find largest bounding box side */ + for (i = 0; i < no_dims; i++) + { + side_len = bbox[2 * i + 1] - bbox[2 * i]; + if (side_len > size) + { + dim = i; + size = side_len; + } + } + + min_val = bbox[2 * dim]; + max_val = bbox[2 * dim + 1]; + + /* Check for zero length or inconsistent */ + if (min_val >= max_val) + return 1; + + /* Use middle for splitting */ + split = (min_val + max_val) / 2; + + /* Partition all data points around middle */ + p = start_idx; + q = end_idx; + while (p <= q) + { + if (PA(p, dim) < split) + { + p++; + } + else if (PA(q, dim) >= split) + { + /* Guard for underflow */ + if (q > 0) + { + q--; + } + else + { + break; + } + } + else + { + PASWAP_int64_t(p, q); + p++; + q--; + } + } + + /* Check for empty splits */ + if (p == start_idx) + { + /* No points less than split. + Split at lowest point instead. + Minimum 1 point will be in lower box. + */ + + uint64_t j = start_idx; + split = PA(j, dim); + for (i2 = start_idx + 1; i2 <= end_idx; i2++) + { + /* Find lowest point */ + cur_val = PA(i2, dim); + if (cur_val < split) + { + j = i2; + split = cur_val; + } + } + PASWAP_int64_t(j, start_idx); + p = start_idx + 1; + } + else if (p == end_idx + 1) + { + /* No points greater than split. + Split at highest point instead. + Minimum 1 point will be in higher box. + */ + + uint64_t j = end_idx; + split = PA(j, dim); + for (i2 = start_idx; i2 < end_idx; i2++) + { + /* Find highest point */ + cur_val = PA(i2, dim); + if (cur_val > split) + { + j = i2; + split = cur_val; + } + } + PASWAP_int64_t(j, end_idx); + p = end_idx; + } + + /* Set return values */ + *cut_dim = dim; + *cut_val = split; + *n_lo = p - start_idx; + return 0; +} + +/************************************************ +Construct a sub tree over a range of data points. +Params: + pa : data points + pidx : permutation index of data points + no_dims: number of dimensions + start_idx : index of first data point to use + n : number of data points + bsp : number of points per leaf + bbox : bounding box of set of data points +************************************************/ +Node_float_int64_t* construct_subtree_float_int64_t(float *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, uint64_t bsp, float *bbox) +{ + /* Create new node */ + int is_leaf = (n <= bsp); + Node_float_int64_t *root = create_node_float_int64_t(start_idx, n, is_leaf); + int rval; + int8_t cut_dim; + uint64_t n_lo; + float cut_val, lv, hv; + if (is_leaf) + { + /* Make leaf node */ + root->cut_dim = -1; + } + else + { + /* Make split node */ + /* Partition data set and set node info */ + rval = partition_float_int64_t(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); + if (rval == 1) + { + root->cut_dim = -1; + return root; + } + root->cut_val = cut_val; + root->cut_dim = cut_dim; + + /* Recurse on both subsets */ + lv = bbox[2 * cut_dim]; + hv = bbox[2 * cut_dim + 1]; + + /* Set bounds for cut dimension */ + root->cut_bounds_lv = lv; + root->cut_bounds_hv = hv; + + /* Update bounding box before call to lower subset and restore after */ + bbox[2 * cut_dim + 1] = cut_val; + root->left_child = (struct Node_float_int64_t *)construct_subtree_float_int64_t(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); + bbox[2 * cut_dim + 1] = hv; + + /* Update bounding box before call to higher subset and restore after */ + bbox[2 * cut_dim] = cut_val; + root->right_child = (struct Node_float_int64_t *)construct_subtree_float_int64_t(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); + bbox[2 * cut_dim] = lv; + } + return root; +} + +/************************************************ +Construct a tree over data points. +Params: + pa : data points + no_dims: number of dimensions + n : number of data points + bsp : number of points per leaf +************************************************/ +Tree_float_int64_t* construct_tree_float_int64_t(float *pa, int8_t no_dims, uint64_t n, uint64_t bsp) +{ + Tree_float_int64_t *tree = (Tree_float_int64_t *)malloc(sizeof(Tree_float_int64_t)); + uint64_t i; + uint64_t *pidx; + float *bbox; + + tree->no_dims = no_dims; + + /* Initialize permutation array */ + pidx = (uint64_t *)malloc(sizeof(uint64_t) * n); + for (i = 0; i < n; i++) + { + pidx[i] = i; + } + + bbox = (float *)malloc(2 * sizeof(float) * no_dims); + get_bounding_box_float_int64_t(pa, pidx, no_dims, n, bbox); + tree->bbox = bbox; + + /* Construct subtree on full dataset */ + tree->root = (struct Node_float_int64_t *)construct_subtree_float_int64_t(pa, pidx, no_dims, 0, n, bsp, bbox); + + tree->pidx = pidx; + return tree; +} + +/************************************************ +Create a tree node. +Params: + start_idx : index of first data point to use + n : number of data points +************************************************/ +Node_float_int64_t* create_node_float_int64_t(uint64_t start_idx, uint64_t n, int is_leaf) +{ + Node_float_int64_t *new_node; + if (is_leaf) + { + /* + Allocate only the part of the struct that will be used in a leaf node. + This relies on the C99 specification of struct layout conservation and padding and + that dereferencing is never attempted for the node pointers in a leaf. + */ + new_node = (Node_float_int64_t *)malloc(sizeof(Node_float_int64_t) - 2 * sizeof(Node_float_int64_t *)); + } + else + { + new_node = (Node_float_int64_t *)malloc(sizeof(Node_float_int64_t)); + } + new_node->n = n; + new_node->start_idx = start_idx; + return new_node; +} + +/************************************************ +Delete subtree +Params: + root : root node of subtree to delete +************************************************/ +void delete_subtree_float_int64_t(Node_float_int64_t *root) +{ + if (root->cut_dim != -1) + { + delete_subtree_float_int64_t((Node_float_int64_t *)root->left_child); + delete_subtree_float_int64_t((Node_float_int64_t *)root->right_child); + } + free(root); +} + +/************************************************ +Delete tree +Params: + tree : Tree struct of kd tree +************************************************/ +void delete_tree_float_int64_t(Tree_float_int64_t *tree) +{ + delete_subtree_float_int64_t((Node_float_int64_t *)tree->root); + free(tree->bbox); + free(tree->pidx); + free(tree); +} + +/************************************************ +Print +************************************************/ +void print_tree_float_int64_t(Node_float_int64_t *root, int level) +{ + int i; + for (i = 0; i < level; i++) + { + printf(" "); + } + printf("(cut_val: %f, cut_dim: %i)\n", root->cut_val, root->cut_dim); + if (root->cut_dim != -1) + print_tree_float_int64_t((Node_float_int64_t *)root->left_child, level + 1); + if (root->cut_dim != -1) + print_tree_float_int64_t((Node_float_int64_t *)root->right_child, level + 1); +} + +/************************************************ +Search a leaf node for closest point +Params: + pa : data points + pidx : permutation index of data points + no_dims : number of dimensions + start_idx : index of first data point to use + size : number of data points + point_coord : query point + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_leaf_float_int64_t(float *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, float *restrict point_coord, + uint64_t k, uint64_t *restrict closest_idx, float *restrict closest_dist) +{ + float cur_dist; + uint64_t i; + /* Loop through all points in leaf */ + for (i = 0; i < n; i++) + { + /* Get distance to query point */ + cur_dist = calc_dist_float(&PA(start_idx + i, 0), point_coord, no_dims); + /* Update closest info if new point is closest so far*/ + if (cur_dist < closest_dist[k - 1]) + { + insert_point_float_int64_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + } + } +} + + +/************************************************ +Search a leaf node for closest point with data point mask +Params: + pa : data points + pidx : permutation index of data points + no_dims : number of dimensions + start_idx : index of first data point to use + size : number of data points + point_coord : query point + mask : boolean array of invalid (True) and valid (False) data points + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_leaf_float_int64_t_mask(float *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, float *restrict point_coord, + uint64_t k, uint8_t *mask, uint64_t *restrict closest_idx, float *restrict closest_dist) +{ + float cur_dist; + uint64_t i; + /* Loop through all points in leaf */ + for (i = 0; i < n; i++) + { + /* Is this point masked out? */ + if (mask[pidx[start_idx + i]]) + { + continue; + } + /* Get distance to query point */ + cur_dist = calc_dist_float(&PA(start_idx + i, 0), point_coord, no_dims); + /* Update closest info if new point is closest so far*/ + if (cur_dist < closest_dist[k - 1]) + { + insert_point_float_int64_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + } + } +} + +/************************************************ +Search subtree for nearest to query point +Params: + root : root node of subtree + pa : data points + pidx : permutation index of data points + no_dims : number of dimensions + point_coord : query point + min_dist : minumum distance to nearest neighbour + mask : boolean array of invalid (True) and valid (False) data points + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_splitnode_float_int64_t(Node_float_int64_t *root, float *pa, uint64_t *pidx, int8_t no_dims, float *point_coord, + float min_dist, uint64_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, + uint64_t *closest_idx, float *closest_dist) +{ + int8_t dim; + float dist_left, dist_right; + float new_offset; + float box_diff; + + /* Skip if distance bound exeeded */ + if (min_dist > distance_upper_bound) + { + return; + } + + dim = root->cut_dim; + + /* Handle leaf node */ + if (dim == -1) + { + if (mask) + { + search_leaf_float_int64_t_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); + } + else + { + search_leaf_float_int64_t(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); + } + return; + } + + /* Get distance to cutting plane */ + new_offset = point_coord[dim] - root->cut_val; + + if (new_offset < 0) + { + /* Left of cutting plane */ + dist_left = min_dist; + if (dist_left < closest_dist[k - 1] * eps_fac) + { + /* Search left subtree if minimum distance is below limit */ + search_splitnode_float_int64_t((Node_float_int64_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + + /* Right of cutting plane. Update minimum distance. + See Algorithms for Fast Vector Quantization + Sunil Arya and David M. Mount. */ + box_diff = root->cut_bounds_lv - point_coord[dim]; + if (box_diff < 0) + { + box_diff = 0; + } + dist_right = min_dist - box_diff * box_diff + new_offset * new_offset; + if (dist_right < closest_dist[k - 1] * eps_fac) + { + /* Search right subtree if minimum distance is below limit*/ + search_splitnode_float_int64_t((Node_float_int64_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + } + else + { + /* Right of cutting plane */ + dist_right = min_dist; + if (dist_right < closest_dist[k - 1] * eps_fac) + { + /* Search right subtree if minimum distance is below limit*/ + search_splitnode_float_int64_t((Node_float_int64_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + + /* Left of cutting plane. Update minimum distance. + See Algorithms for Fast Vector Quantization + Sunil Arya and David M. Mount. */ + box_diff = point_coord[dim] - root->cut_bounds_hv; + if (box_diff < 0) + { + box_diff = 0; + } + dist_left = min_dist - box_diff * box_diff + new_offset * new_offset; + if (dist_left < closest_dist[k - 1] * eps_fac) + { + /* Search left subtree if minimum distance is below limit*/ + search_splitnode_float_int64_t((Node_float_int64_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + } + } +} + +/************************************************ +Search for nearest neighbour for a set of query points +Params: + tree : Tree struct of kd tree + pa : data points + pidx : permutation index of data points + point_coords : query points + num_points : number of query points + mask : boolean array of invalid (True) and valid (False) data points + closest_idx : index of closest data point found (return) + closest_dist : distance to closest point (return) +************************************************/ +void search_tree_float_int64_t(Tree_float_int64_t *tree, float *pa, float *point_coords, + uint64_t num_points, uint64_t k, float distance_upper_bound, + float eps, uint8_t *mask, uint64_t *closest_idxs, float *closest_dists) +{ + float min_dist; + float eps_fac = 1 / ((1 + eps) * (1 + eps)); + int8_t no_dims = tree->no_dims; + float *bbox = tree->bbox; + uint64_t *pidx = tree->pidx; + uint64_t j = 0; +#if defined(_MSC_VER) && defined(_OPENMP) + int64_t i = 0; + int64_t local_num_points = (int64_t) num_points; +#else + uint64_t i; + uint64_t local_num_points = num_points; +#endif + Node_float_int64_t *root = (Node_float_int64_t *)tree->root; + + /* Queries are OpenMP enabled */ + #pragma omp parallel + { + /* The low chunk size is important to avoid L2 cache trashing + for spatial coherent query datasets + */ + #pragma omp for private(i, j) schedule(static, 100) nowait + for (i = 0; i < local_num_points; i++) + { + for (j = 0; j < k; j++) + { + closest_idxs[i * k + j] = IDX_MAX_int64_t; + closest_dists[i * k + j] = DIST_MAX_float; + } + min_dist = get_min_dist_float(point_coords + no_dims * i, no_dims, bbox); + search_splitnode_float_int64_t(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, + k, distance_upper_bound, eps_fac, mask, &closest_idxs[i * k], &closest_dists[i * k]); + } + } +} + +/************************************************ +Calculate squared cartesian distance between points +Params: + point1_coord : point 1 + point2_coord : point 2 +************************************************/ +double calc_dist_double(double *point1_coord, double *point2_coord, int8_t no_dims) +{ + /* Calculate squared distance */ + double dist = 0, dim_dist; + int8_t i; + for (i = 0; i < no_dims; i++) + { + dim_dist = point2_coord[i] - point1_coord[i]; + dist += dim_dist * dim_dist; + } + return dist; +} + +/************************************************ +Get squared distance from point to cube in specified dimension +Params: + dim : dimension + point_coord : cartesian coordinates of point + bbox : cube +************************************************/ +double get_cube_offset_double(int8_t dim, double *point_coord, double *bbox) +{ + double dim_coord = point_coord[dim]; -void insert_point_float(uint32_t *closest_idx, float *closest_dist, uint32_t pidx, float cur_dist, uint32_t k); -void get_bounding_box_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, float *bbox); -int partition_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *bbox, int8_t *cut_dim, - float *cut_val, uint32_t *n_lo); -Tree_float* construct_tree_float(float *pa, int8_t no_dims, uint32_t n, uint32_t bsp); -Node_float* construct_subtree_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, float *bbox); -Node_float * create_node_float(uint32_t start_idx, uint32_t n, int is_leaf); -void delete_subtree_float(Node_float *root); -void delete_tree_float(Tree_float *tree); -void print_tree_float(Node_float *root, int level); -float calc_dist_float(float *point1_coord, float *point2_coord, int8_t no_dims); -float get_cube_offset_float(int8_t dim, float *point_coord, float *bbox); -float get_min_dist_float(float *point_coord, int8_t no_dims, float *bbox); -void search_leaf_float(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, - uint32_t k, uint32_t *restrict closest_idx, float *restrict closest_dist); -void search_leaf_float_mask(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, - uint32_t k, uint8_t *restrict mask, uint32_t *restrict closest_idx, float *restrict closest_dist); -void search_splitnode_float(Node_float *root, float *pa, uint32_t *pidx, int8_t no_dims, float *point_coord, - float min_dist, uint32_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, uint32_t * closest_idx, float *closest_dist); -void search_tree_float(Tree_float *tree, float *pa, float *point_coords, - uint32_t num_points, uint32_t k, float distance_upper_bound, - float eps, uint8_t *mask, uint32_t *closest_idxs, float *closest_dists); + if (dim_coord < bbox[2 * dim]) + { + /* Left of cube in dimension */ + return dim_coord - bbox[2 * dim]; + } + else if (dim_coord > bbox[2 * dim + 1]) + { + /* Right of cube in dimension */ + return dim_coord - bbox[2 * dim + 1]; + } + else + { + /* Inside cube in dimension */ + return 0.; + } +} +/************************************************ +Get minimum squared distance between point and cube. +Params: + point_coord : cartesian coordinates of point + no_dims : number of dimensions + bbox : cube +************************************************/ +double get_min_dist_double(double *point_coord, int8_t no_dims, double *bbox) +{ + double cube_offset = 0, cube_offset_dim; + int8_t i; -void insert_point_double(uint32_t *closest_idx, double *closest_dist, uint32_t pidx, double cur_dist, uint32_t k); -void get_bounding_box_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, double *bbox); -int partition_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *bbox, int8_t *cut_dim, - double *cut_val, uint32_t *n_lo); -Tree_double* construct_tree_double(double *pa, int8_t no_dims, uint32_t n, uint32_t bsp); -Node_double* construct_subtree_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, double *bbox); -Node_double * create_node_double(uint32_t start_idx, uint32_t n, int is_leaf); -void delete_subtree_double(Node_double *root); -void delete_tree_double(Tree_double *tree); -void print_tree_double(Node_double *root, int level); -double calc_dist_double(double *point1_coord, double *point2_coord, int8_t no_dims); -double get_cube_offset_double(int8_t dim, double *point_coord, double *bbox); -double get_min_dist_double(double *point_coord, int8_t no_dims, double *bbox); -void search_leaf_double(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, - uint32_t k, uint32_t *restrict closest_idx, double *restrict closest_dist); -void search_leaf_double_mask(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, - uint32_t k, uint8_t *restrict mask, uint32_t *restrict closest_idx, double *restrict closest_dist); -void search_splitnode_double(Node_double *root, double *pa, uint32_t *pidx, int8_t no_dims, double *point_coord, - double min_dist, uint32_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, uint32_t * closest_idx, double *closest_dist); -void search_tree_double(Tree_double *tree, double *pa, double *point_coords, - uint32_t num_points, uint32_t k, double distance_upper_bound, - double eps, uint8_t *mask, uint32_t *closest_idxs, double *closest_dists); + for (i = 0; i < no_dims; i++) + { + cube_offset_dim = get_cube_offset_double(i, point_coord, bbox); + cube_offset += cube_offset_dim * cube_offset_dim; + } + return cube_offset; +} /************************************************ @@ -137,7 +1521,7 @@ Insert point into priority queue cur_dist : distance to point inserted k : number of neighbours ************************************************/ -void insert_point_float(uint32_t *closest_idx, float *closest_dist, uint32_t pidx, float cur_dist, uint32_t k) +void insert_point_double_int32_t(uint32_t *closest_idx, double *closest_dist, uint32_t pidx, double cur_dist, uint32_t k) { int i; for (i = k - 1; i > 0; i--) @@ -165,9 +1549,9 @@ Get the bounding box of a set of points n : number of points bbox : bounding box (return) ************************************************/ -void get_bounding_box_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, float *bbox) +void get_bounding_box_double_int32_t(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, double *bbox) { - float cur; + double cur; int8_t i, j; uint32_t bbox_idx, i2; @@ -210,11 +1594,11 @@ The sliding midpoint rule is used for the partitioning. cut_val : value of cutting point (return) n_lo : number of point below cutting plane (return) ************************************************/ -int partition_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *bbox, int8_t *cut_dim, float *cut_val, uint32_t *n_lo) +int partition_double_int32_t(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *bbox, int8_t *cut_dim, double *cut_val, uint32_t *n_lo) { int8_t dim = 0, i; uint32_t p, q, i2; - float size = 0, min_val, max_val, split, side_len, cur_val; + double size = 0, min_val, max_val, split, side_len, cur_val; uint32_t end_idx = start_idx + n - 1; /* Find largest bounding box side */ @@ -261,7 +1645,7 @@ int partition_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_id } else { - PASWAP(p, q); + PASWAP_int32_t(p, q); p++; q--; } @@ -287,7 +1671,7 @@ int partition_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_id split = cur_val; } } - PASWAP(j, start_idx); + PASWAP_int32_t(j, start_idx); p = start_idx + 1; } else if (p == end_idx + 1) @@ -309,7 +1693,7 @@ int partition_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_id split = cur_val; } } - PASWAP(j, end_idx); + PASWAP_int32_t(j, end_idx); p = end_idx; } @@ -331,15 +1715,15 @@ Construct a sub tree over a range of data points. bsp : number of points per leaf bbox : bounding box of set of data points ************************************************/ -Node_float* construct_subtree_float(float *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, float *bbox) +Node_double_int32_t* construct_subtree_double_int32_t(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, double *bbox) { /* Create new node */ int is_leaf = (n <= bsp); - Node_float *root = create_node_float(start_idx, n, is_leaf); + Node_double_int32_t *root = create_node_double_int32_t(start_idx, n, is_leaf); int rval; int8_t cut_dim; uint32_t n_lo; - float cut_val, lv, hv; + double cut_val, lv, hv; if (is_leaf) { /* Make leaf node */ @@ -349,7 +1733,7 @@ Node_float* construct_subtree_float(float *pa, uint32_t *pidx, int8_t no_dims, u { /* Make split node */ /* Partition data set and set node info */ - rval = partition_float(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); + rval = partition_double_int32_t(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); if (rval == 1) { root->cut_dim = -1; @@ -368,12 +1752,12 @@ Node_float* construct_subtree_float(float *pa, uint32_t *pidx, int8_t no_dims, u /* Update bounding box before call to lower subset and restore after */ bbox[2 * cut_dim + 1] = cut_val; - root->left_child = (struct Node_float *)construct_subtree_float(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); + root->left_child = (struct Node_double_int32_t *)construct_subtree_double_int32_t(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); bbox[2 * cut_dim + 1] = hv; /* Update bounding box before call to higher subset and restore after */ bbox[2 * cut_dim] = cut_val; - root->right_child = (struct Node_float *)construct_subtree_float(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); + root->right_child = (struct Node_double_int32_t *)construct_subtree_double_int32_t(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); bbox[2 * cut_dim] = lv; } return root; @@ -387,12 +1771,12 @@ Construct a tree over data points. n : number of data points bsp : number of points per leaf ************************************************/ -Tree_float* construct_tree_float(float *pa, int8_t no_dims, uint32_t n, uint32_t bsp) +Tree_double_int32_t* construct_tree_double_int32_t(double *pa, int8_t no_dims, uint32_t n, uint32_t bsp) { - Tree_float *tree = (Tree_float *)malloc(sizeof(Tree_float)); + Tree_double_int32_t *tree = (Tree_double_int32_t *)malloc(sizeof(Tree_double_int32_t)); uint32_t i; uint32_t *pidx; - float *bbox; + double *bbox; tree->no_dims = no_dims; @@ -403,12 +1787,12 @@ Tree_float* construct_tree_float(float *pa, int8_t no_dims, uint32_t n, uint32_t pidx[i] = i; } - bbox = (float *)malloc(2 * sizeof(float) * no_dims); - get_bounding_box_float(pa, pidx, no_dims, n, bbox); + bbox = (double *)malloc(2 * sizeof(double) * no_dims); + get_bounding_box_double_int32_t(pa, pidx, no_dims, n, bbox); tree->bbox = bbox; /* Construct subtree on full dataset */ - tree->root = (struct Node_float *)construct_subtree_float(pa, pidx, no_dims, 0, n, bsp, bbox); + tree->root = (struct Node_double_int32_t *)construct_subtree_double_int32_t(pa, pidx, no_dims, 0, n, bsp, bbox); tree->pidx = pidx; return tree; @@ -420,9 +1804,9 @@ Create a tree node. start_idx : index of first data point to use n : number of data points ************************************************/ -Node_float* create_node_float(uint32_t start_idx, uint32_t n, int is_leaf) +Node_double_int32_t* create_node_double_int32_t(uint32_t start_idx, uint32_t n, int is_leaf) { - Node_float *new_node; + Node_double_int32_t *new_node; if (is_leaf) { /* @@ -430,11 +1814,11 @@ Node_float* create_node_float(uint32_t start_idx, uint32_t n, int is_leaf) This relies on the C99 specification of struct layout conservation and padding and that dereferencing is never attempted for the node pointers in a leaf. */ - new_node = (Node_float *)malloc(sizeof(Node_float) - 2 * sizeof(Node_float *)); + new_node = (Node_double_int32_t *)malloc(sizeof(Node_double_int32_t) - 2 * sizeof(Node_double_int32_t *)); } else { - new_node = (Node_float *)malloc(sizeof(Node_float)); + new_node = (Node_double_int32_t *)malloc(sizeof(Node_double_int32_t)); } new_node->n = n; new_node->start_idx = start_idx; @@ -446,12 +1830,12 @@ Delete subtree Params: root : root node of subtree to delete ************************************************/ -void delete_subtree_float(Node_float *root) +void delete_subtree_double_int32_t(Node_double_int32_t *root) { if (root->cut_dim != -1) { - delete_subtree_float((Node_float *)root->left_child); - delete_subtree_float((Node_float *)root->right_child); + delete_subtree_double_int32_t((Node_double_int32_t *)root->left_child); + delete_subtree_double_int32_t((Node_double_int32_t *)root->right_child); } free(root); } @@ -461,9 +1845,9 @@ Delete tree Params: tree : Tree struct of kd tree ************************************************/ -void delete_tree_float(Tree_float *tree) +void delete_tree_double_int32_t(Tree_double_int32_t *tree) { - delete_subtree_float((Node_float *)tree->root); + delete_subtree_double_int32_t((Node_double_int32_t *)tree->root); free(tree->bbox); free(tree->pidx); free(tree); @@ -472,7 +1856,7 @@ void delete_tree_float(Tree_float *tree) /************************************************ Print ************************************************/ -void print_tree_float(Node_float *root, int level) +void print_tree_double_int32_t(Node_double_int32_t *root, int level) { int i; for (i = 0; i < level; i++) @@ -481,77 +1865,9 @@ void print_tree_float(Node_float *root, int level) } printf("(cut_val: %f, cut_dim: %i)\n", root->cut_val, root->cut_dim); if (root->cut_dim != -1) - print_tree_float((Node_float *)root->left_child, level + 1); + print_tree_double_int32_t((Node_double_int32_t *)root->left_child, level + 1); if (root->cut_dim != -1) - print_tree_float((Node_float *)root->right_child, level + 1); -} - -/************************************************ -Calculate squared cartesian distance between points -Params: - point1_coord : point 1 - point2_coord : point 2 -************************************************/ -float calc_dist_float(float *point1_coord, float *point2_coord, int8_t no_dims) -{ - /* Calculate squared distance */ - float dist = 0, dim_dist; - int8_t i; - for (i = 0; i < no_dims; i++) - { - dim_dist = point2_coord[i] - point1_coord[i]; - dist += dim_dist * dim_dist; - } - return dist; -} - -/************************************************ -Get squared distance from point to cube in specified dimension -Params: - dim : dimension - point_coord : cartesian coordinates of point - bbox : cube -************************************************/ -float get_cube_offset_float(int8_t dim, float *point_coord, float *bbox) -{ - float dim_coord = point_coord[dim]; - - if (dim_coord < bbox[2 * dim]) - { - /* Left of cube in dimension */ - return dim_coord - bbox[2 * dim]; - } - else if (dim_coord > bbox[2 * dim + 1]) - { - /* Right of cube in dimension */ - return dim_coord - bbox[2 * dim + 1]; - } - else - { - /* Inside cube in dimension */ - return 0.; - } -} - -/************************************************ -Get minimum squared distance between point and cube. -Params: - point_coord : cartesian coordinates of point - no_dims : number of dimensions - bbox : cube -************************************************/ -float get_min_dist_float(float *point_coord, int8_t no_dims, float *bbox) -{ - float cube_offset = 0, cube_offset_dim; - int8_t i; - - for (i = 0; i < no_dims; i++) - { - cube_offset_dim = get_cube_offset_float(i, point_coord, bbox); - cube_offset += cube_offset_dim * cube_offset_dim; - } - - return cube_offset; + print_tree_double_int32_t((Node_double_int32_t *)root->right_child, level + 1); } /************************************************ @@ -566,20 +1882,20 @@ Search a leaf node for closest point closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_leaf_float(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, - uint32_t k, uint32_t *restrict closest_idx, float *restrict closest_dist) +void search_leaf_double_int32_t(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, + uint32_t k, uint32_t *restrict closest_idx, double *restrict closest_dist) { - float cur_dist; + double cur_dist; uint32_t i; /* Loop through all points in leaf */ for (i = 0; i < n; i++) { /* Get distance to query point */ - cur_dist = calc_dist_float(&PA(start_idx + i, 0), point_coord, no_dims); + cur_dist = calc_dist_double(&PA(start_idx + i, 0), point_coord, no_dims); /* Update closest info if new point is closest so far*/ if (cur_dist < closest_dist[k - 1]) { - insert_point_float(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + insert_point_double_int32_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); } } } @@ -598,10 +1914,10 @@ Search a leaf node for closest point with data point mask closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_leaf_float_mask(float *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, float *restrict point_coord, - uint32_t k, uint8_t *mask, uint32_t *restrict closest_idx, float *restrict closest_dist) +void search_leaf_double_int32_t_mask(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, + uint32_t k, uint8_t *mask, uint32_t *restrict closest_idx, double *restrict closest_dist) { - float cur_dist; + double cur_dist; uint32_t i; /* Loop through all points in leaf */ for (i = 0; i < n; i++) @@ -612,11 +1928,11 @@ void search_leaf_float_mask(float *restrict pa, uint32_t *restrict pidx, int8_t continue; } /* Get distance to query point */ - cur_dist = calc_dist_float(&PA(start_idx + i, 0), point_coord, no_dims); + cur_dist = calc_dist_double(&PA(start_idx + i, 0), point_coord, no_dims); /* Update closest info if new point is closest so far*/ if (cur_dist < closest_dist[k - 1]) { - insert_point_float(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + insert_point_double_int32_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); } } } @@ -634,14 +1950,14 @@ Search subtree for nearest to query point closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_splitnode_float(Node_float *root, float *pa, uint32_t *pidx, int8_t no_dims, float *point_coord, - float min_dist, uint32_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, - uint32_t *closest_idx, float *closest_dist) +void search_splitnode_double_int32_t(Node_double_int32_t *root, double *pa, uint32_t *pidx, int8_t no_dims, double *point_coord, + double min_dist, uint32_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, + uint32_t *closest_idx, double *closest_dist) { int8_t dim; - float dist_left, dist_right; - float new_offset; - float box_diff; + double dist_left, dist_right; + double new_offset; + double box_diff; /* Skip if distance bound exeeded */ if (min_dist > distance_upper_bound) @@ -656,11 +1972,11 @@ void search_splitnode_float(Node_float *root, float *pa, uint32_t *pidx, int8_t { if (mask) { - search_leaf_float_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); + search_leaf_double_int32_t_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); } else { - search_leaf_float(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); + search_leaf_double_int32_t(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); } return; } @@ -675,7 +1991,7 @@ void search_splitnode_float(Node_float *root, float *pa, uint32_t *pidx, int8_t if (dist_left < closest_dist[k - 1] * eps_fac) { /* Search left subtree if minimum distance is below limit */ - search_splitnode_float((Node_float *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int32_t((Node_double_int32_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } /* Right of cutting plane. Update minimum distance. @@ -690,7 +2006,7 @@ void search_splitnode_float(Node_float *root, float *pa, uint32_t *pidx, int8_t if (dist_right < closest_dist[k - 1] * eps_fac) { /* Search right subtree if minimum distance is below limit*/ - search_splitnode_float((Node_float *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int32_t((Node_double_int32_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } } else @@ -700,7 +2016,7 @@ void search_splitnode_float(Node_float *root, float *pa, uint32_t *pidx, int8_t if (dist_right < closest_dist[k - 1] * eps_fac) { /* Search right subtree if minimum distance is below limit*/ - search_splitnode_float((Node_float *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int32_t((Node_double_int32_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } /* Left of cutting plane. Update minimum distance. @@ -715,7 +2031,7 @@ void search_splitnode_float(Node_float *root, float *pa, uint32_t *pidx, int8_t if (dist_left < closest_dist[k - 1] * eps_fac) { /* Search left subtree if minimum distance is below limit*/ - search_splitnode_float((Node_float *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int32_t((Node_double_int32_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } } } @@ -732,14 +2048,14 @@ Search for nearest neighbour for a set of query points closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_tree_float(Tree_float *tree, float *pa, float *point_coords, - uint32_t num_points, uint32_t k, float distance_upper_bound, - float eps, uint8_t *mask, uint32_t *closest_idxs, float *closest_dists) +void search_tree_double_int32_t(Tree_double_int32_t *tree, double *pa, double *point_coords, + uint32_t num_points, uint32_t k, double distance_upper_bound, + double eps, uint8_t *mask, uint32_t *closest_idxs, double *closest_dists) { - float min_dist; - float eps_fac = 1 / ((1 + eps) * (1 + eps)); + double min_dist; + double eps_fac = 1 / ((1 + eps) * (1 + eps)); int8_t no_dims = tree->no_dims; - float *bbox = tree->bbox; + double *bbox = tree->bbox; uint32_t *pidx = tree->pidx; uint32_t j = 0; #if defined(_MSC_VER) && defined(_OPENMP) @@ -749,7 +2065,7 @@ void search_tree_float(Tree_float *tree, float *pa, float *point_coords, uint32_t i; uint32_t local_num_points = num_points; #endif - Node_float *root = (Node_float *)tree->root; + Node_double_int32_t *root = (Node_double_int32_t *)tree->root; /* Queries are OpenMP enabled */ #pragma omp parallel @@ -762,11 +2078,11 @@ void search_tree_float(Tree_float *tree, float *pa, float *point_coords, { for (j = 0; j < k; j++) { - closest_idxs[i * k + j] = UINT32_MAX; - closest_dists[i * k + j] = DBL_MAX; + closest_idxs[i * k + j] = IDX_MAX_int32_t; + closest_dists[i * k + j] = DIST_MAX_double; } - min_dist = get_min_dist_float(point_coords + no_dims * i, no_dims, bbox); - search_splitnode_float(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, + min_dist = get_min_dist_double(point_coords + no_dims * i, no_dims, bbox); + search_splitnode_double_int32_t(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, k, distance_upper_bound, eps_fac, mask, &closest_idxs[i * k], &closest_dists[i * k]); } } @@ -781,7 +2097,7 @@ Insert point into priority queue cur_dist : distance to point inserted k : number of neighbours ************************************************/ -void insert_point_double(uint32_t *closest_idx, double *closest_dist, uint32_t pidx, double cur_dist, uint32_t k) +void insert_point_double_int64_t(uint64_t *closest_idx, double *closest_dist, uint64_t pidx, double cur_dist, uint64_t k) { int i; for (i = k - 1; i > 0; i--) @@ -809,11 +2125,11 @@ Get the bounding box of a set of points n : number of points bbox : bounding box (return) ************************************************/ -void get_bounding_box_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, double *bbox) +void get_bounding_box_double_int64_t(double *pa, uint64_t *pidx, int8_t no_dims, uint64_t n, double *bbox) { double cur; int8_t i, j; - uint32_t bbox_idx, i2; + uint64_t bbox_idx, i2; /* Use first data point to initialize */ for (i = 0; i < no_dims; i++) @@ -854,12 +2170,12 @@ The sliding midpoint rule is used for the partitioning. cut_val : value of cutting point (return) n_lo : number of point below cutting plane (return) ************************************************/ -int partition_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *bbox, int8_t *cut_dim, double *cut_val, uint32_t *n_lo) +int partition_double_int64_t(double *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, double *bbox, int8_t *cut_dim, double *cut_val, uint64_t *n_lo) { int8_t dim = 0, i; - uint32_t p, q, i2; + uint64_t p, q, i2; double size = 0, min_val, max_val, split, side_len, cur_val; - uint32_t end_idx = start_idx + n - 1; + uint64_t end_idx = start_idx + n - 1; /* Find largest bounding box side */ for (i = 0; i < no_dims; i++) @@ -905,7 +2221,7 @@ int partition_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_ } else { - PASWAP(p, q); + PASWAP_int64_t(p, q); p++; q--; } @@ -919,7 +2235,7 @@ int partition_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_ Minimum 1 point will be in lower box. */ - uint32_t j = start_idx; + uint64_t j = start_idx; split = PA(j, dim); for (i2 = start_idx + 1; i2 <= end_idx; i2++) { @@ -931,7 +2247,7 @@ int partition_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_ split = cur_val; } } - PASWAP(j, start_idx); + PASWAP_int64_t(j, start_idx); p = start_idx + 1; } else if (p == end_idx + 1) @@ -941,7 +2257,7 @@ int partition_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_ Minimum 1 point will be in higher box. */ - uint32_t j = end_idx; + uint64_t j = end_idx; split = PA(j, dim); for (i2 = start_idx; i2 < end_idx; i2++) { @@ -953,7 +2269,7 @@ int partition_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_ split = cur_val; } } - PASWAP(j, end_idx); + PASWAP_int64_t(j, end_idx); p = end_idx; } @@ -975,14 +2291,14 @@ Construct a sub tree over a range of data points. bsp : number of points per leaf bbox : bounding box of set of data points ************************************************/ -Node_double* construct_subtree_double(double *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, double *bbox) +Node_double_int64_t* construct_subtree_double_int64_t(double *pa, uint64_t *pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, uint64_t bsp, double *bbox) { /* Create new node */ int is_leaf = (n <= bsp); - Node_double *root = create_node_double(start_idx, n, is_leaf); + Node_double_int64_t *root = create_node_double_int64_t(start_idx, n, is_leaf); int rval; int8_t cut_dim; - uint32_t n_lo; + uint64_t n_lo; double cut_val, lv, hv; if (is_leaf) { @@ -993,7 +2309,7 @@ Node_double* construct_subtree_double(double *pa, uint32_t *pidx, int8_t no_dims { /* Make split node */ /* Partition data set and set node info */ - rval = partition_double(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); + rval = partition_double_int64_t(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); if (rval == 1) { root->cut_dim = -1; @@ -1012,12 +2328,12 @@ Node_double* construct_subtree_double(double *pa, uint32_t *pidx, int8_t no_dims /* Update bounding box before call to lower subset and restore after */ bbox[2 * cut_dim + 1] = cut_val; - root->left_child = (struct Node_double *)construct_subtree_double(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); + root->left_child = (struct Node_double_int64_t *)construct_subtree_double_int64_t(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); bbox[2 * cut_dim + 1] = hv; /* Update bounding box before call to higher subset and restore after */ bbox[2 * cut_dim] = cut_val; - root->right_child = (struct Node_double *)construct_subtree_double(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); + root->right_child = (struct Node_double_int64_t *)construct_subtree_double_int64_t(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); bbox[2 * cut_dim] = lv; } return root; @@ -1031,28 +2347,28 @@ Construct a tree over data points. n : number of data points bsp : number of points per leaf ************************************************/ -Tree_double* construct_tree_double(double *pa, int8_t no_dims, uint32_t n, uint32_t bsp) +Tree_double_int64_t* construct_tree_double_int64_t(double *pa, int8_t no_dims, uint64_t n, uint64_t bsp) { - Tree_double *tree = (Tree_double *)malloc(sizeof(Tree_double)); - uint32_t i; - uint32_t *pidx; + Tree_double_int64_t *tree = (Tree_double_int64_t *)malloc(sizeof(Tree_double_int64_t)); + uint64_t i; + uint64_t *pidx; double *bbox; tree->no_dims = no_dims; /* Initialize permutation array */ - pidx = (uint32_t *)malloc(sizeof(uint32_t) * n); + pidx = (uint64_t *)malloc(sizeof(uint64_t) * n); for (i = 0; i < n; i++) { pidx[i] = i; } bbox = (double *)malloc(2 * sizeof(double) * no_dims); - get_bounding_box_double(pa, pidx, no_dims, n, bbox); + get_bounding_box_double_int64_t(pa, pidx, no_dims, n, bbox); tree->bbox = bbox; /* Construct subtree on full dataset */ - tree->root = (struct Node_double *)construct_subtree_double(pa, pidx, no_dims, 0, n, bsp, bbox); + tree->root = (struct Node_double_int64_t *)construct_subtree_double_int64_t(pa, pidx, no_dims, 0, n, bsp, bbox); tree->pidx = pidx; return tree; @@ -1064,9 +2380,9 @@ Create a tree node. start_idx : index of first data point to use n : number of data points ************************************************/ -Node_double* create_node_double(uint32_t start_idx, uint32_t n, int is_leaf) +Node_double_int64_t* create_node_double_int64_t(uint64_t start_idx, uint64_t n, int is_leaf) { - Node_double *new_node; + Node_double_int64_t *new_node; if (is_leaf) { /* @@ -1074,11 +2390,11 @@ Node_double* create_node_double(uint32_t start_idx, uint32_t n, int is_leaf) This relies on the C99 specification of struct layout conservation and padding and that dereferencing is never attempted for the node pointers in a leaf. */ - new_node = (Node_double *)malloc(sizeof(Node_double) - 2 * sizeof(Node_double *)); + new_node = (Node_double_int64_t *)malloc(sizeof(Node_double_int64_t) - 2 * sizeof(Node_double_int64_t *)); } else { - new_node = (Node_double *)malloc(sizeof(Node_double)); + new_node = (Node_double_int64_t *)malloc(sizeof(Node_double_int64_t)); } new_node->n = n; new_node->start_idx = start_idx; @@ -1090,12 +2406,12 @@ Delete subtree Params: root : root node of subtree to delete ************************************************/ -void delete_subtree_double(Node_double *root) +void delete_subtree_double_int64_t(Node_double_int64_t *root) { if (root->cut_dim != -1) { - delete_subtree_double((Node_double *)root->left_child); - delete_subtree_double((Node_double *)root->right_child); + delete_subtree_double_int64_t((Node_double_int64_t *)root->left_child); + delete_subtree_double_int64_t((Node_double_int64_t *)root->right_child); } free(root); } @@ -1105,9 +2421,9 @@ Delete tree Params: tree : Tree struct of kd tree ************************************************/ -void delete_tree_double(Tree_double *tree) +void delete_tree_double_int64_t(Tree_double_int64_t *tree) { - delete_subtree_double((Node_double *)tree->root); + delete_subtree_double_int64_t((Node_double_int64_t *)tree->root); free(tree->bbox); free(tree->pidx); free(tree); @@ -1116,7 +2432,7 @@ void delete_tree_double(Tree_double *tree) /************************************************ Print ************************************************/ -void print_tree_double(Node_double *root, int level) +void print_tree_double_int64_t(Node_double_int64_t *root, int level) { int i; for (i = 0; i < level; i++) @@ -1125,77 +2441,9 @@ void print_tree_double(Node_double *root, int level) } printf("(cut_val: %f, cut_dim: %i)\n", root->cut_val, root->cut_dim); if (root->cut_dim != -1) - print_tree_double((Node_double *)root->left_child, level + 1); + print_tree_double_int64_t((Node_double_int64_t *)root->left_child, level + 1); if (root->cut_dim != -1) - print_tree_double((Node_double *)root->right_child, level + 1); -} - -/************************************************ -Calculate squared cartesian distance between points -Params: - point1_coord : point 1 - point2_coord : point 2 -************************************************/ -double calc_dist_double(double *point1_coord, double *point2_coord, int8_t no_dims) -{ - /* Calculate squared distance */ - double dist = 0, dim_dist; - int8_t i; - for (i = 0; i < no_dims; i++) - { - dim_dist = point2_coord[i] - point1_coord[i]; - dist += dim_dist * dim_dist; - } - return dist; -} - -/************************************************ -Get squared distance from point to cube in specified dimension -Params: - dim : dimension - point_coord : cartesian coordinates of point - bbox : cube -************************************************/ -double get_cube_offset_double(int8_t dim, double *point_coord, double *bbox) -{ - double dim_coord = point_coord[dim]; - - if (dim_coord < bbox[2 * dim]) - { - /* Left of cube in dimension */ - return dim_coord - bbox[2 * dim]; - } - else if (dim_coord > bbox[2 * dim + 1]) - { - /* Right of cube in dimension */ - return dim_coord - bbox[2 * dim + 1]; - } - else - { - /* Inside cube in dimension */ - return 0.; - } -} - -/************************************************ -Get minimum squared distance between point and cube. -Params: - point_coord : cartesian coordinates of point - no_dims : number of dimensions - bbox : cube -************************************************/ -double get_min_dist_double(double *point_coord, int8_t no_dims, double *bbox) -{ - double cube_offset = 0, cube_offset_dim; - int8_t i; - - for (i = 0; i < no_dims; i++) - { - cube_offset_dim = get_cube_offset_double(i, point_coord, bbox); - cube_offset += cube_offset_dim * cube_offset_dim; - } - - return cube_offset; + print_tree_double_int64_t((Node_double_int64_t *)root->right_child, level + 1); } /************************************************ @@ -1210,11 +2458,11 @@ Search a leaf node for closest point closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_leaf_double(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, - uint32_t k, uint32_t *restrict closest_idx, double *restrict closest_dist) +void search_leaf_double_int64_t(double *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, double *restrict point_coord, + uint64_t k, uint64_t *restrict closest_idx, double *restrict closest_dist) { double cur_dist; - uint32_t i; + uint64_t i; /* Loop through all points in leaf */ for (i = 0; i < n; i++) { @@ -1223,7 +2471,7 @@ void search_leaf_double(double *restrict pa, uint32_t *restrict pidx, int8_t no_ /* Update closest info if new point is closest so far*/ if (cur_dist < closest_dist[k - 1]) { - insert_point_double(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + insert_point_double_int64_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); } } } @@ -1242,11 +2490,11 @@ Search a leaf node for closest point with data point mask closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_leaf_double_mask(double *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, double *restrict point_coord, - uint32_t k, uint8_t *mask, uint32_t *restrict closest_idx, double *restrict closest_dist) +void search_leaf_double_int64_t_mask(double *restrict pa, uint64_t *restrict pidx, int8_t no_dims, uint64_t start_idx, uint64_t n, double *restrict point_coord, + uint64_t k, uint8_t *mask, uint64_t *restrict closest_idx, double *restrict closest_dist) { double cur_dist; - uint32_t i; + uint64_t i; /* Loop through all points in leaf */ for (i = 0; i < n; i++) { @@ -1260,7 +2508,7 @@ void search_leaf_double_mask(double *restrict pa, uint32_t *restrict pidx, int8_ /* Update closest info if new point is closest so far*/ if (cur_dist < closest_dist[k - 1]) { - insert_point_double(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + insert_point_double_int64_t(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); } } } @@ -1278,9 +2526,9 @@ Search subtree for nearest to query point closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_splitnode_double(Node_double *root, double *pa, uint32_t *pidx, int8_t no_dims, double *point_coord, - double min_dist, uint32_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, - uint32_t *closest_idx, double *closest_dist) +void search_splitnode_double_int64_t(Node_double_int64_t *root, double *pa, uint64_t *pidx, int8_t no_dims, double *point_coord, + double min_dist, uint64_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, + uint64_t *closest_idx, double *closest_dist) { int8_t dim; double dist_left, dist_right; @@ -1300,11 +2548,11 @@ void search_splitnode_double(Node_double *root, double *pa, uint32_t *pidx, int8 { if (mask) { - search_leaf_double_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); + search_leaf_double_int64_t_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); } else { - search_leaf_double(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); + search_leaf_double_int64_t(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); } return; } @@ -1319,7 +2567,7 @@ void search_splitnode_double(Node_double *root, double *pa, uint32_t *pidx, int8 if (dist_left < closest_dist[k - 1] * eps_fac) { /* Search left subtree if minimum distance is below limit */ - search_splitnode_double((Node_double *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int64_t((Node_double_int64_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } /* Right of cutting plane. Update minimum distance. @@ -1334,7 +2582,7 @@ void search_splitnode_double(Node_double *root, double *pa, uint32_t *pidx, int8 if (dist_right < closest_dist[k - 1] * eps_fac) { /* Search right subtree if minimum distance is below limit*/ - search_splitnode_double((Node_double *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int64_t((Node_double_int64_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } } else @@ -1344,7 +2592,7 @@ void search_splitnode_double(Node_double *root, double *pa, uint32_t *pidx, int8 if (dist_right < closest_dist[k - 1] * eps_fac) { /* Search right subtree if minimum distance is below limit*/ - search_splitnode_double((Node_double *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int64_t((Node_double_int64_t *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } /* Left of cutting plane. Update minimum distance. @@ -1359,7 +2607,7 @@ void search_splitnode_double(Node_double *root, double *pa, uint32_t *pidx, int8 if (dist_left < closest_dist[k - 1] * eps_fac) { /* Search left subtree if minimum distance is below limit*/ - search_splitnode_double((Node_double *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_double_int64_t((Node_double_int64_t *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } } } @@ -1376,24 +2624,24 @@ Search for nearest neighbour for a set of query points closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_tree_double(Tree_double *tree, double *pa, double *point_coords, - uint32_t num_points, uint32_t k, double distance_upper_bound, - double eps, uint8_t *mask, uint32_t *closest_idxs, double *closest_dists) +void search_tree_double_int64_t(Tree_double_int64_t *tree, double *pa, double *point_coords, + uint64_t num_points, uint64_t k, double distance_upper_bound, + double eps, uint8_t *mask, uint64_t *closest_idxs, double *closest_dists) { double min_dist; double eps_fac = 1 / ((1 + eps) * (1 + eps)); int8_t no_dims = tree->no_dims; double *bbox = tree->bbox; - uint32_t *pidx = tree->pidx; - uint32_t j = 0; + uint64_t *pidx = tree->pidx; + uint64_t j = 0; #if defined(_MSC_VER) && defined(_OPENMP) - int32_t i = 0; - int32_t local_num_points = (int32_t) num_points; + int64_t i = 0; + int64_t local_num_points = (int64_t) num_points; #else - uint32_t i; - uint32_t local_num_points = num_points; + uint64_t i; + uint64_t local_num_points = num_points; #endif - Node_double *root = (Node_double *)tree->root; + Node_double_int64_t *root = (Node_double_int64_t *)tree->root; /* Queries are OpenMP enabled */ #pragma omp parallel @@ -1406,11 +2654,11 @@ void search_tree_double(Tree_double *tree, double *pa, double *point_coords, { for (j = 0; j < k; j++) { - closest_idxs[i * k + j] = UINT32_MAX; - closest_dists[i * k + j] = DBL_MAX; + closest_idxs[i * k + j] = IDX_MAX_int64_t; + closest_dists[i * k + j] = DIST_MAX_double; } min_dist = get_min_dist_double(point_coords + no_dims * i, no_dims, bbox); - search_splitnode_double(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, + search_splitnode_double_int64_t(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, k, distance_upper_bound, eps_fac, mask, &closest_idxs[i * k], &closest_dists[i * k]); } } diff --git a/pykdtree/_kdtree_core.c.mako b/pykdtree/_kdtree_core.c.mako index 6e3eb8b..dbe31f8 100644 --- a/pykdtree/_kdtree_core.c.mako +++ b/pykdtree/_kdtree_core.c.mako @@ -29,65 +29,148 @@ Anne M. Archibald and libANN by David M. Mount and Sunil Arya. #include #define PA(i,d) (pa[no_dims * pidx[i] + d]) -#define PASWAP(a,b) { uint32_t tmp = pidx[a]; pidx[a] = pidx[b]; pidx[b] = tmp; } +% for ITYPE in ['int32_t', 'int64_t']: +#define PASWAP_${ITYPE}(a,b) { u${ITYPE} tmp = pidx[a]; pidx[a] = pidx[b]; pidx[b] = tmp; } +% endfor + +#define IDX_MAX_int32_t UINT32_MAX +#define IDX_MAX_int64_t UINT64_MAX +#define DIST_MAX_float FLT_MAX +#define DIST_MAX_double DBL_MAX #ifdef _MSC_VER #define restrict __restrict #endif % for DTYPE in ['float', 'double']: +% for ITYPE in ['int32_t', 'int64_t']: typedef struct { ${DTYPE} cut_val; int8_t cut_dim; - uint32_t start_idx; - uint32_t n; + u${ITYPE} start_idx; + u${ITYPE} n; ${DTYPE} cut_bounds_lv; ${DTYPE} cut_bounds_hv; - struct Node_${DTYPE} *left_child; - struct Node_${DTYPE} *right_child; -} Node_${DTYPE}; + struct Node_${DTYPE}_${ITYPE} *left_child; + struct Node_${DTYPE}_${ITYPE} *right_child; +} Node_${DTYPE}_${ITYPE}; typedef struct { ${DTYPE} *bbox; int8_t no_dims; - uint32_t *pidx; - struct Node_${DTYPE} *root; -} Tree_${DTYPE}; + u${ITYPE} *pidx; + struct Node_${DTYPE}_${ITYPE} *root; +} Tree_${DTYPE}_${ITYPE}; +% endfor % endfor % for DTYPE in ['float', 'double']: -void insert_point_${DTYPE}(uint32_t *closest_idx, ${DTYPE} *closest_dist, uint32_t pidx, ${DTYPE} cur_dist, uint32_t k); -void get_bounding_box_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, ${DTYPE} *bbox); -int partition_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, ${DTYPE} *bbox, int8_t *cut_dim, - ${DTYPE} *cut_val, uint32_t *n_lo); -Tree_${DTYPE}* construct_tree_${DTYPE}(${DTYPE} *pa, int8_t no_dims, uint32_t n, uint32_t bsp); -Node_${DTYPE}* construct_subtree_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, ${DTYPE} *bbox); -Node_${DTYPE} * create_node_${DTYPE}(uint32_t start_idx, uint32_t n, int is_leaf); -void delete_subtree_${DTYPE}(Node_${DTYPE} *root); -void delete_tree_${DTYPE}(Tree_${DTYPE} *tree); -void print_tree_${DTYPE}(Node_${DTYPE} *root, int level); ${DTYPE} calc_dist_${DTYPE}(${DTYPE} *point1_coord, ${DTYPE} *point2_coord, int8_t no_dims); ${DTYPE} get_cube_offset_${DTYPE}(int8_t dim, ${DTYPE} *point_coord, ${DTYPE} *bbox); ${DTYPE} get_min_dist_${DTYPE}(${DTYPE} *point_coord, int8_t no_dims, ${DTYPE} *bbox); -void search_leaf_${DTYPE}(${DTYPE} *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, ${DTYPE} *restrict point_coord, - uint32_t k, uint32_t *restrict closest_idx, ${DTYPE} *restrict closest_dist); -void search_leaf_${DTYPE}_mask(${DTYPE} *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, ${DTYPE} *restrict point_coord, - uint32_t k, uint8_t *restrict mask, uint32_t *restrict closest_idx, ${DTYPE} *restrict closest_dist); -void search_splitnode_${DTYPE}(Node_${DTYPE} *root, ${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, ${DTYPE} *point_coord, - ${DTYPE} min_dist, uint32_t k, ${DTYPE} distance_upper_bound, ${DTYPE} eps_fac, uint8_t *mask, uint32_t * closest_idx, ${DTYPE} *closest_dist); -void search_tree_${DTYPE}(Tree_${DTYPE} *tree, ${DTYPE} *pa, ${DTYPE} *point_coords, - uint32_t num_points, uint32_t k, ${DTYPE} distance_upper_bound, - ${DTYPE} eps, uint8_t *mask, uint32_t *closest_idxs, ${DTYPE} *closest_dists); +% for ITYPE in ['int32_t', 'int64_t']: + +void insert_point_${DTYPE}_${ITYPE}(u${ITYPE} *closest_idx, ${DTYPE} *closest_dist, u${ITYPE} pidx, ${DTYPE} cur_dist, u${ITYPE} k); +void get_bounding_box_${DTYPE}_${ITYPE}(${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, u${ITYPE} n, ${DTYPE} *bbox); +int partition_${DTYPE}_${ITYPE}(${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, ${DTYPE} *bbox, int8_t *cut_dim, + ${DTYPE} *cut_val, u${ITYPE} *n_lo); +Tree_${DTYPE}_${ITYPE}* construct_tree_${DTYPE}_${ITYPE}(${DTYPE} *pa, int8_t no_dims, u${ITYPE} n, u${ITYPE} bsp); +Node_${DTYPE}_${ITYPE}* construct_subtree_${DTYPE}_${ITYPE}(${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, u${ITYPE} bsp, ${DTYPE} *bbox); +Node_${DTYPE}_${ITYPE} * create_node_${DTYPE}_${ITYPE}(u${ITYPE} start_idx, u${ITYPE} n, int is_leaf); +void delete_subtree_${DTYPE}_${ITYPE}(Node_${DTYPE}_${ITYPE} *root); +void delete_tree_${DTYPE}_${ITYPE}(Tree_${DTYPE}_${ITYPE} *tree); +void print_tree_${DTYPE}_${ITYPE}(Node_${DTYPE}_${ITYPE} *root, int level); +void search_leaf_${DTYPE}_${ITYPE}(${DTYPE} *restrict pa, u${ITYPE} *restrict pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, ${DTYPE} *restrict point_coord, + u${ITYPE} k, u${ITYPE} *restrict closest_idx, ${DTYPE} *restrict closest_dist); +void search_leaf_${DTYPE}_${ITYPE}_mask(${DTYPE} *restrict pa, u${ITYPE} *restrict pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, ${DTYPE} *restrict point_coord, + u${ITYPE} k, uint8_t *restrict mask, u${ITYPE} *restrict closest_idx, ${DTYPE} *restrict closest_dist); +void search_splitnode_${DTYPE}_${ITYPE}(Node_${DTYPE}_${ITYPE} *root, ${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, ${DTYPE} *point_coord, + ${DTYPE} min_dist, u${ITYPE} k, ${DTYPE} distance_upper_bound, ${DTYPE} eps_fac, uint8_t *mask, u${ITYPE} * closest_idx, ${DTYPE} *closest_dist); +void search_tree_${DTYPE}_${ITYPE}(Tree_${DTYPE}_${ITYPE} *tree, ${DTYPE} *pa, ${DTYPE} *point_coords, + u${ITYPE} num_points, u${ITYPE} k, ${DTYPE} distance_upper_bound, + ${DTYPE} eps, uint8_t *mask, u${ITYPE} *closest_idxs, ${DTYPE} *closest_dists); + +% endfor % endfor % for DTYPE in ['float', 'double']: +/************************************************ +Calculate squared cartesian distance between points +Params: + point1_coord : point 1 + point2_coord : point 2 +************************************************/ +${DTYPE} calc_dist_${DTYPE}(${DTYPE} *point1_coord, ${DTYPE} *point2_coord, int8_t no_dims) +{ + /* Calculate squared distance */ + ${DTYPE} dist = 0, dim_dist; + int8_t i; + for (i = 0; i < no_dims; i++) + { + dim_dist = point2_coord[i] - point1_coord[i]; + dist += dim_dist * dim_dist; + } + return dist; +} + +/************************************************ +Get squared distance from point to cube in specified dimension +Params: + dim : dimension + point_coord : cartesian coordinates of point + bbox : cube +************************************************/ +${DTYPE} get_cube_offset_${DTYPE}(int8_t dim, ${DTYPE} *point_coord, ${DTYPE} *bbox) +{ + ${DTYPE} dim_coord = point_coord[dim]; + + if (dim_coord < bbox[2 * dim]) + { + /* Left of cube in dimension */ + return dim_coord - bbox[2 * dim]; + } + else if (dim_coord > bbox[2 * dim + 1]) + { + /* Right of cube in dimension */ + return dim_coord - bbox[2 * dim + 1]; + } + else + { + /* Inside cube in dimension */ + return 0.; + } +} + +/************************************************ +Get minimum squared distance between point and cube. +Params: + point_coord : cartesian coordinates of point + no_dims : number of dimensions + bbox : cube +************************************************/ +${DTYPE} get_min_dist_${DTYPE}(${DTYPE} *point_coord, int8_t no_dims, ${DTYPE} *bbox) +{ + ${DTYPE} cube_offset = 0, cube_offset_dim; + int8_t i; + + for (i = 0; i < no_dims; i++) + { + cube_offset_dim = get_cube_offset_${DTYPE}(i, point_coord, bbox); + cube_offset += cube_offset_dim * cube_offset_dim; + } + + return cube_offset; +} + +% for ITYPE in ['int32_t', 'int64_t']: + /************************************************ Insert point into priority queue Params: @@ -97,7 +180,7 @@ Params: cur_dist : distance to point inserted k : number of neighbours ************************************************/ -void insert_point_${DTYPE}(uint32_t *closest_idx, ${DTYPE} *closest_dist, uint32_t pidx, ${DTYPE} cur_dist, uint32_t k) +void insert_point_${DTYPE}_${ITYPE}(u${ITYPE} *closest_idx, ${DTYPE} *closest_dist, u${ITYPE} pidx, ${DTYPE} cur_dist, u${ITYPE} k) { int i; for (i = k - 1; i > 0; i--) @@ -125,11 +208,11 @@ Params: n : number of points bbox : bounding box (return) ************************************************/ -void get_bounding_box_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t n, ${DTYPE} *bbox) +void get_bounding_box_${DTYPE}_${ITYPE}(${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, u${ITYPE} n, ${DTYPE} *bbox) { ${DTYPE} cur; int8_t i, j; - uint32_t bbox_idx, i2; + u${ITYPE} bbox_idx, i2; /* Use first data point to initialize */ for (i = 0; i < no_dims; i++) @@ -170,12 +253,12 @@ Params: cut_val : value of cutting point (return) n_lo : number of point below cutting plane (return) ************************************************/ -int partition_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, ${DTYPE} *bbox, int8_t *cut_dim, ${DTYPE} *cut_val, uint32_t *n_lo) +int partition_${DTYPE}_${ITYPE}(${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, ${DTYPE} *bbox, int8_t *cut_dim, ${DTYPE} *cut_val, u${ITYPE} *n_lo) { int8_t dim = 0, i; - uint32_t p, q, i2; + u${ITYPE} p, q, i2; ${DTYPE} size = 0, min_val, max_val, split, side_len, cur_val; - uint32_t end_idx = start_idx + n - 1; + u${ITYPE} end_idx = start_idx + n - 1; /* Find largest bounding box side */ for (i = 0; i < no_dims; i++) @@ -221,7 +304,7 @@ int partition_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t st } else { - PASWAP(p, q); + PASWAP_${ITYPE}(p, q); p++; q--; } @@ -235,7 +318,7 @@ int partition_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t st Minimum 1 point will be in lower box. */ - uint32_t j = start_idx; + u${ITYPE} j = start_idx; split = PA(j, dim); for (i2 = start_idx + 1; i2 <= end_idx; i2++) { @@ -247,7 +330,7 @@ int partition_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t st split = cur_val; } } - PASWAP(j, start_idx); + PASWAP_${ITYPE}(j, start_idx); p = start_idx + 1; } else if (p == end_idx + 1) @@ -257,7 +340,7 @@ int partition_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t st Minimum 1 point will be in higher box. */ - uint32_t j = end_idx; + u${ITYPE} j = end_idx; split = PA(j, dim); for (i2 = start_idx; i2 < end_idx; i2++) { @@ -269,7 +352,7 @@ int partition_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t st split = cur_val; } } - PASWAP(j, end_idx); + PASWAP_${ITYPE}(j, end_idx); p = end_idx; } @@ -291,14 +374,14 @@ Params: bsp : number of points per leaf bbox : bounding box of set of data points ************************************************/ -Node_${DTYPE}* construct_subtree_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, uint32_t bsp, ${DTYPE} *bbox) +Node_${DTYPE}_${ITYPE}* construct_subtree_${DTYPE}_${ITYPE}(${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, u${ITYPE} bsp, ${DTYPE} *bbox) { /* Create new node */ int is_leaf = (n <= bsp); - Node_${DTYPE} *root = create_node_${DTYPE}(start_idx, n, is_leaf); + Node_${DTYPE}_${ITYPE} *root = create_node_${DTYPE}_${ITYPE}(start_idx, n, is_leaf); int rval; int8_t cut_dim; - uint32_t n_lo; + u${ITYPE} n_lo; ${DTYPE} cut_val, lv, hv; if (is_leaf) { @@ -309,7 +392,7 @@ Node_${DTYPE}* construct_subtree_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t n { /* Make split node */ /* Partition data set and set node info */ - rval = partition_${DTYPE}(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); + rval = partition_${DTYPE}_${ITYPE}(pa, pidx, no_dims, start_idx, n, bbox, &cut_dim, &cut_val, &n_lo); if (rval == 1) { root->cut_dim = -1; @@ -328,12 +411,12 @@ Node_${DTYPE}* construct_subtree_${DTYPE}(${DTYPE} *pa, uint32_t *pidx, int8_t n /* Update bounding box before call to lower subset and restore after */ bbox[2 * cut_dim + 1] = cut_val; - root->left_child = (struct Node_${DTYPE} *)construct_subtree_${DTYPE}(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); + root->left_child = (struct Node_${DTYPE}_${ITYPE} *)construct_subtree_${DTYPE}_${ITYPE}(pa, pidx, no_dims, start_idx, n_lo, bsp, bbox); bbox[2 * cut_dim + 1] = hv; /* Update bounding box before call to higher subset and restore after */ bbox[2 * cut_dim] = cut_val; - root->right_child = (struct Node_${DTYPE} *)construct_subtree_${DTYPE}(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); + root->right_child = (struct Node_${DTYPE}_${ITYPE} *)construct_subtree_${DTYPE}_${ITYPE}(pa, pidx, no_dims, start_idx + n_lo, n - n_lo, bsp, bbox); bbox[2 * cut_dim] = lv; } return root; @@ -347,28 +430,28 @@ Params: n : number of data points bsp : number of points per leaf ************************************************/ -Tree_${DTYPE}* construct_tree_${DTYPE}(${DTYPE} *pa, int8_t no_dims, uint32_t n, uint32_t bsp) +Tree_${DTYPE}_${ITYPE}* construct_tree_${DTYPE}_${ITYPE}(${DTYPE} *pa, int8_t no_dims, u${ITYPE} n, u${ITYPE} bsp) { - Tree_${DTYPE} *tree = (Tree_${DTYPE} *)malloc(sizeof(Tree_${DTYPE})); - uint32_t i; - uint32_t *pidx; + Tree_${DTYPE}_${ITYPE} *tree = (Tree_${DTYPE}_${ITYPE} *)malloc(sizeof(Tree_${DTYPE}_${ITYPE})); + u${ITYPE} i; + u${ITYPE} *pidx; ${DTYPE} *bbox; tree->no_dims = no_dims; /* Initialize permutation array */ - pidx = (uint32_t *)malloc(sizeof(uint32_t) * n); + pidx = (u${ITYPE} *)malloc(sizeof(u${ITYPE}) * n); for (i = 0; i < n; i++) { pidx[i] = i; } bbox = (${DTYPE} *)malloc(2 * sizeof(${DTYPE}) * no_dims); - get_bounding_box_${DTYPE}(pa, pidx, no_dims, n, bbox); + get_bounding_box_${DTYPE}_${ITYPE}(pa, pidx, no_dims, n, bbox); tree->bbox = bbox; /* Construct subtree on full dataset */ - tree->root = (struct Node_${DTYPE} *)construct_subtree_${DTYPE}(pa, pidx, no_dims, 0, n, bsp, bbox); + tree->root = (struct Node_${DTYPE}_${ITYPE} *)construct_subtree_${DTYPE}_${ITYPE}(pa, pidx, no_dims, 0, n, bsp, bbox); tree->pidx = pidx; return tree; @@ -380,9 +463,9 @@ Params: start_idx : index of first data point to use n : number of data points ************************************************/ -Node_${DTYPE}* create_node_${DTYPE}(uint32_t start_idx, uint32_t n, int is_leaf) +Node_${DTYPE}_${ITYPE}* create_node_${DTYPE}_${ITYPE}(u${ITYPE} start_idx, u${ITYPE} n, int is_leaf) { - Node_${DTYPE} *new_node; + Node_${DTYPE}_${ITYPE} *new_node; if (is_leaf) { /* @@ -390,11 +473,11 @@ Node_${DTYPE}* create_node_${DTYPE}(uint32_t start_idx, uint32_t n, int is_leaf) This relies on the C99 specification of struct layout conservation and padding and that dereferencing is never attempted for the node pointers in a leaf. */ - new_node = (Node_${DTYPE} *)malloc(sizeof(Node_${DTYPE}) - 2 * sizeof(Node_${DTYPE} *)); + new_node = (Node_${DTYPE}_${ITYPE} *)malloc(sizeof(Node_${DTYPE}_${ITYPE}) - 2 * sizeof(Node_${DTYPE}_${ITYPE} *)); } else { - new_node = (Node_${DTYPE} *)malloc(sizeof(Node_${DTYPE})); + new_node = (Node_${DTYPE}_${ITYPE} *)malloc(sizeof(Node_${DTYPE}_${ITYPE})); } new_node->n = n; new_node->start_idx = start_idx; @@ -406,12 +489,12 @@ Delete subtree Params: root : root node of subtree to delete ************************************************/ -void delete_subtree_${DTYPE}(Node_${DTYPE} *root) +void delete_subtree_${DTYPE}_${ITYPE}(Node_${DTYPE}_${ITYPE} *root) { if (root->cut_dim != -1) { - delete_subtree_${DTYPE}((Node_${DTYPE} *)root->left_child); - delete_subtree_${DTYPE}((Node_${DTYPE} *)root->right_child); + delete_subtree_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->left_child); + delete_subtree_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->right_child); } free(root); } @@ -421,9 +504,9 @@ Delete tree Params: tree : Tree struct of kd tree ************************************************/ -void delete_tree_${DTYPE}(Tree_${DTYPE} *tree) +void delete_tree_${DTYPE}_${ITYPE}(Tree_${DTYPE}_${ITYPE} *tree) { - delete_subtree_${DTYPE}((Node_${DTYPE} *)tree->root); + delete_subtree_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)tree->root); free(tree->bbox); free(tree->pidx); free(tree); @@ -432,7 +515,7 @@ void delete_tree_${DTYPE}(Tree_${DTYPE} *tree) /************************************************ Print ************************************************/ -void print_tree_${DTYPE}(Node_${DTYPE} *root, int level) +void print_tree_${DTYPE}_${ITYPE}(Node_${DTYPE}_${ITYPE} *root, int level) { int i; for (i = 0; i < level; i++) @@ -441,77 +524,9 @@ void print_tree_${DTYPE}(Node_${DTYPE} *root, int level) } printf("(cut_val: %f, cut_dim: %i)\n", root->cut_val, root->cut_dim); if (root->cut_dim != -1) - print_tree_${DTYPE}((Node_${DTYPE} *)root->left_child, level + 1); + print_tree_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->left_child, level + 1); if (root->cut_dim != -1) - print_tree_${DTYPE}((Node_${DTYPE} *)root->right_child, level + 1); -} - -/************************************************ -Calculate squared cartesian distance between points -Params: - point1_coord : point 1 - point2_coord : point 2 -************************************************/ -${DTYPE} calc_dist_${DTYPE}(${DTYPE} *point1_coord, ${DTYPE} *point2_coord, int8_t no_dims) -{ - /* Calculate squared distance */ - ${DTYPE} dist = 0, dim_dist; - int8_t i; - for (i = 0; i < no_dims; i++) - { - dim_dist = point2_coord[i] - point1_coord[i]; - dist += dim_dist * dim_dist; - } - return dist; -} - -/************************************************ -Get squared distance from point to cube in specified dimension -Params: - dim : dimension - point_coord : cartesian coordinates of point - bbox : cube -************************************************/ -${DTYPE} get_cube_offset_${DTYPE}(int8_t dim, ${DTYPE} *point_coord, ${DTYPE} *bbox) -{ - ${DTYPE} dim_coord = point_coord[dim]; - - if (dim_coord < bbox[2 * dim]) - { - /* Left of cube in dimension */ - return dim_coord - bbox[2 * dim]; - } - else if (dim_coord > bbox[2 * dim + 1]) - { - /* Right of cube in dimension */ - return dim_coord - bbox[2 * dim + 1]; - } - else - { - /* Inside cube in dimension */ - return 0.; - } -} - -/************************************************ -Get minimum squared distance between point and cube. -Params: - point_coord : cartesian coordinates of point - no_dims : number of dimensions - bbox : cube -************************************************/ -${DTYPE} get_min_dist_${DTYPE}(${DTYPE} *point_coord, int8_t no_dims, ${DTYPE} *bbox) -{ - ${DTYPE} cube_offset = 0, cube_offset_dim; - int8_t i; - - for (i = 0; i < no_dims; i++) - { - cube_offset_dim = get_cube_offset_${DTYPE}(i, point_coord, bbox); - cube_offset += cube_offset_dim * cube_offset_dim; - } - - return cube_offset; + print_tree_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->right_child, level + 1); } /************************************************ @@ -526,11 +541,11 @@ Params: closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_leaf_${DTYPE}(${DTYPE} *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, ${DTYPE} *restrict point_coord, - uint32_t k, uint32_t *restrict closest_idx, ${DTYPE} *restrict closest_dist) +void search_leaf_${DTYPE}_${ITYPE}(${DTYPE} *restrict pa, u${ITYPE} *restrict pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, ${DTYPE} *restrict point_coord, + u${ITYPE} k, u${ITYPE} *restrict closest_idx, ${DTYPE} *restrict closest_dist) { ${DTYPE} cur_dist; - uint32_t i; + u${ITYPE} i; /* Loop through all points in leaf */ for (i = 0; i < n; i++) { @@ -539,7 +554,7 @@ void search_leaf_${DTYPE}(${DTYPE} *restrict pa, uint32_t *restrict pidx, int8_t /* Update closest info if new point is closest so far*/ if (cur_dist < closest_dist[k - 1]) { - insert_point_${DTYPE}(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + insert_point_${DTYPE}_${ITYPE}(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); } } } @@ -558,11 +573,11 @@ Params: closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_leaf_${DTYPE}_mask(${DTYPE} *restrict pa, uint32_t *restrict pidx, int8_t no_dims, uint32_t start_idx, uint32_t n, ${DTYPE} *restrict point_coord, - uint32_t k, uint8_t *mask, uint32_t *restrict closest_idx, ${DTYPE} *restrict closest_dist) +void search_leaf_${DTYPE}_${ITYPE}_mask(${DTYPE} *restrict pa, u${ITYPE} *restrict pidx, int8_t no_dims, u${ITYPE} start_idx, u${ITYPE} n, ${DTYPE} *restrict point_coord, + u${ITYPE} k, uint8_t *mask, u${ITYPE} *restrict closest_idx, ${DTYPE} *restrict closest_dist) { ${DTYPE} cur_dist; - uint32_t i; + u${ITYPE} i; /* Loop through all points in leaf */ for (i = 0; i < n; i++) { @@ -576,7 +591,7 @@ void search_leaf_${DTYPE}_mask(${DTYPE} *restrict pa, uint32_t *restrict pidx, i /* Update closest info if new point is closest so far*/ if (cur_dist < closest_dist[k - 1]) { - insert_point_${DTYPE}(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); + insert_point_${DTYPE}_${ITYPE}(closest_idx, closest_dist, pidx[start_idx + i], cur_dist, k); } } } @@ -594,9 +609,9 @@ Params: closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_splitnode_${DTYPE}(Node_${DTYPE} *root, ${DTYPE} *pa, uint32_t *pidx, int8_t no_dims, ${DTYPE} *point_coord, - ${DTYPE} min_dist, uint32_t k, ${DTYPE} distance_upper_bound, ${DTYPE} eps_fac, uint8_t *mask, - uint32_t *closest_idx, ${DTYPE} *closest_dist) +void search_splitnode_${DTYPE}_${ITYPE}(Node_${DTYPE}_${ITYPE} *root, ${DTYPE} *pa, u${ITYPE} *pidx, int8_t no_dims, ${DTYPE} *point_coord, + ${DTYPE} min_dist, u${ITYPE} k, ${DTYPE} distance_upper_bound, ${DTYPE} eps_fac, uint8_t *mask, + u${ITYPE} *closest_idx, ${DTYPE} *closest_dist) { int8_t dim; ${DTYPE} dist_left, dist_right; @@ -616,11 +631,11 @@ void search_splitnode_${DTYPE}(Node_${DTYPE} *root, ${DTYPE} *pa, uint32_t *pidx { if (mask) { - search_leaf_${DTYPE}_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); + search_leaf_${DTYPE}_${ITYPE}_mask(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, mask, closest_idx, closest_dist); } else { - search_leaf_${DTYPE}(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); + search_leaf_${DTYPE}_${ITYPE}(pa, pidx, no_dims, root->start_idx, root->n, point_coord, k, closest_idx, closest_dist); } return; } @@ -635,7 +650,7 @@ void search_splitnode_${DTYPE}(Node_${DTYPE} *root, ${DTYPE} *pa, uint32_t *pidx if (dist_left < closest_dist[k - 1] * eps_fac) { /* Search left subtree if minimum distance is below limit */ - search_splitnode_${DTYPE}((Node_${DTYPE} *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } /* Right of cutting plane. Update minimum distance. @@ -650,7 +665,7 @@ void search_splitnode_${DTYPE}(Node_${DTYPE} *root, ${DTYPE} *pa, uint32_t *pidx if (dist_right < closest_dist[k - 1] * eps_fac) { /* Search right subtree if minimum distance is below limit*/ - search_splitnode_${DTYPE}((Node_${DTYPE} *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } } else @@ -660,7 +675,7 @@ void search_splitnode_${DTYPE}(Node_${DTYPE} *root, ${DTYPE} *pa, uint32_t *pidx if (dist_right < closest_dist[k - 1] * eps_fac) { /* Search right subtree if minimum distance is below limit*/ - search_splitnode_${DTYPE}((Node_${DTYPE} *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->right_child, pa, pidx, no_dims, point_coord, dist_right, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } /* Left of cutting plane. Update minimum distance. @@ -675,7 +690,7 @@ void search_splitnode_${DTYPE}(Node_${DTYPE} *root, ${DTYPE} *pa, uint32_t *pidx if (dist_left < closest_dist[k - 1] * eps_fac) { /* Search left subtree if minimum distance is below limit*/ - search_splitnode_${DTYPE}((Node_${DTYPE} *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); + search_splitnode_${DTYPE}_${ITYPE}((Node_${DTYPE}_${ITYPE} *)root->left_child, pa, pidx, no_dims, point_coord, dist_left, k, distance_upper_bound, eps_fac, mask, closest_idx, closest_dist); } } } @@ -692,24 +707,24 @@ Params: closest_idx : index of closest data point found (return) closest_dist : distance to closest point (return) ************************************************/ -void search_tree_${DTYPE}(Tree_${DTYPE} *tree, ${DTYPE} *pa, ${DTYPE} *point_coords, - uint32_t num_points, uint32_t k, ${DTYPE} distance_upper_bound, - ${DTYPE} eps, uint8_t *mask, uint32_t *closest_idxs, ${DTYPE} *closest_dists) +void search_tree_${DTYPE}_${ITYPE}(Tree_${DTYPE}_${ITYPE} *tree, ${DTYPE} *pa, ${DTYPE} *point_coords, + u${ITYPE} num_points, u${ITYPE} k, ${DTYPE} distance_upper_bound, + ${DTYPE} eps, uint8_t *mask, u${ITYPE} *closest_idxs, ${DTYPE} *closest_dists) { ${DTYPE} min_dist; ${DTYPE} eps_fac = 1 / ((1 + eps) * (1 + eps)); int8_t no_dims = tree->no_dims; ${DTYPE} *bbox = tree->bbox; - uint32_t *pidx = tree->pidx; - uint32_t j = 0; + u${ITYPE} *pidx = tree->pidx; + u${ITYPE} j = 0; #if defined(_MSC_VER) && defined(_OPENMP) - int32_t i = 0; - int32_t local_num_points = (int32_t) num_points; + ${ITYPE} i = 0; + ${ITYPE} local_num_points = (${ITYPE}) num_points; #else - uint32_t i; - uint32_t local_num_points = num_points; + u${ITYPE} i; + u${ITYPE} local_num_points = num_points; #endif - Node_${DTYPE} *root = (Node_${DTYPE} *)tree->root; + Node_${DTYPE}_${ITYPE} *root = (Node_${DTYPE}_${ITYPE} *)tree->root; /* Queries are OpenMP enabled */ #pragma omp parallel @@ -722,13 +737,14 @@ void search_tree_${DTYPE}(Tree_${DTYPE} *tree, ${DTYPE} *pa, ${DTYPE} *point_coo { for (j = 0; j < k; j++) { - closest_idxs[i * k + j] = UINT32_MAX; - closest_dists[i * k + j] = DBL_MAX; + closest_idxs[i * k + j] = IDX_MAX_${ITYPE}; + closest_dists[i * k + j] = DIST_MAX_${DTYPE}; } min_dist = get_min_dist_${DTYPE}(point_coords + no_dims * i, no_dims, bbox); - search_splitnode_${DTYPE}(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, + search_splitnode_${DTYPE}_${ITYPE}(root, pa, pidx, no_dims, point_coords + no_dims * i, min_dist, k, distance_upper_bound, eps_fac, mask, &closest_idxs[i * k], &closest_dists[i * k]); } } } % endfor +% endfor diff --git a/pykdtree/kdtree.pyx b/pykdtree/kdtree.pyx index 7104810..5e5d066 100644 --- a/pykdtree/kdtree.pyx +++ b/pykdtree/kdtree.pyx @@ -17,51 +17,91 @@ import numpy as np cimport numpy as np -from libc.stdint cimport uint32_t, int8_t, uint8_t +from libc.stdint cimport uint64_t, uint32_t, int8_t, uint8_t, UINT32_MAX cimport cython np.import_array() # Node structure -cdef struct node_float: +cdef struct node_float_int32_t: float cut_val int8_t cut_dim uint32_t start_idx uint32_t n float cut_bounds_lv float cut_bounds_hv - node_float *left_child - node_float *right_child + node_float_int32_t *left_child + node_float_int32_t *right_child -cdef struct tree_float: +cdef struct tree_float_int32_t: float *bbox int8_t no_dims uint32_t *pidx - node_float *root + node_float_int32_t *root -cdef struct node_double: +cdef struct node_double_int32_t: double cut_val int8_t cut_dim uint32_t start_idx uint32_t n double cut_bounds_lv double cut_bounds_hv - node_double *left_child - node_double *right_child + node_double_int32_t *left_child + node_double_int32_t *right_child -cdef struct tree_double: +cdef struct tree_double_int32_t: double *bbox int8_t no_dims uint32_t *pidx - node_double *root + node_double_int32_t *root -cdef extern tree_float* construct_tree_float(float *pa, int8_t no_dims, uint32_t n, uint32_t bsp) nogil -cdef extern void search_tree_float(tree_float *kdtree, float *pa, float *point_coords, uint32_t num_points, uint32_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, uint32_t *closest_idxs, float *closest_dists) nogil -cdef extern void delete_tree_float(tree_float *kdtree) +cdef struct node_float_int64_t: + float cut_val + int8_t cut_dim + uint64_t start_idx + uint64_t n + float cut_bounds_lv + float cut_bounds_hv + node_float_int64_t *left_child + node_float_int64_t *right_child + +cdef struct tree_float_int64_t: + float *bbox + int8_t no_dims + uint64_t *pidx + node_float_int64_t *root -cdef extern tree_double* construct_tree_double(double *pa, int8_t no_dims, uint32_t n, uint32_t bsp) nogil -cdef extern void search_tree_double(tree_double *kdtree, double *pa, double *point_coords, uint32_t num_points, uint32_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, uint32_t *closest_idxs, double *closest_dists) nogil -cdef extern void delete_tree_double(tree_double *kdtree) +cdef struct node_double_int64_t: + double cut_val + int8_t cut_dim + uint64_t start_idx + uint64_t n + double cut_bounds_lv + double cut_bounds_hv + node_double_int64_t *left_child + node_double_int64_t *right_child + +cdef struct tree_double_int64_t: + double *bbox + int8_t no_dims + uint64_t *pidx + node_double_int64_t *root + +cdef extern tree_float_int32_t* construct_tree_float_int32_t(float *pa, int8_t no_dims, uint32_t n, uint32_t bsp) nogil +cdef extern void search_tree_float_int32_t(tree_float_int32_t *kdtree, float *pa, float *point_coords, uint32_t num_points, uint32_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, uint32_t *closest_idxs, float *closest_dists) nogil +cdef extern void delete_tree_float_int32_t(tree_float_int32_t *kdtree) + +cdef extern tree_double_int32_t* construct_tree_double_int32_t(double *pa, int8_t no_dims, uint32_t n, uint32_t bsp) nogil +cdef extern void search_tree_double_int32_t(tree_double_int32_t *kdtree, double *pa, double *point_coords, uint32_t num_points, uint32_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, uint32_t *closest_idxs, double *closest_dists) nogil +cdef extern void delete_tree_double_int32_t(tree_double_int32_t *kdtree) + +cdef extern tree_float_int64_t* construct_tree_float_int64_t(float *pa, int8_t no_dims, uint64_t n, uint64_t bsp) nogil +cdef extern void search_tree_float_int64_t(tree_float_int64_t *kdtree, float *pa, float *point_coords, uint64_t num_points, uint64_t k, float distance_upper_bound, float eps_fac, uint8_t *mask, uint64_t *closest_idxs, float *closest_dists) nogil +cdef extern void delete_tree_float_int64_t(tree_float_int64_t *kdtree) + +cdef extern tree_double_int64_t* construct_tree_double_int64_t(double *pa, int8_t no_dims, uint64_t n, uint64_t bsp) nogil +cdef extern void search_tree_double_int64_t(tree_double_int64_t *kdtree, double *pa, double *point_coords, uint64_t num_points, uint64_t k, double distance_upper_bound, double eps_fac, uint8_t *mask, uint64_t *closest_idxs, double *closest_dists) nogil +cdef extern void delete_tree_double_int64_t(tree_double_int64_t *kdtree) cdef class KDTree: """kd-tree for fast nearest-neighbour lookup. @@ -73,23 +113,30 @@ cdef class KDTree: Data points with shape (n , dims) leafsize : int, optional Maximum number of data points in tree leaf + index_bits : int, optional + Number of bits (32 or 64) to use for indexing. Default is 32, use 64 bits if n * k > 2^32. """ - cdef tree_float *_kdtree_float - cdef tree_double *_kdtree_double + cdef tree_float_int32_t *_kdtree_float_int32_t + cdef tree_double_int32_t *_kdtree_double_int32_t + cdef tree_float_int64_t *_kdtree_float_int64_t + cdef tree_double_int64_t *_kdtree_double_int64_t cdef readonly np.ndarray data_pts cdef readonly np.ndarray data cdef float *_data_pts_data_float cdef double *_data_pts_data_double - cdef readonly uint32_t n + cdef readonly uint64_t n cdef readonly int8_t ndim cdef readonly uint32_t leafsize + cdef readonly uint32_t index_bits def __cinit__(KDTree self): - self._kdtree_float = NULL - self._kdtree_double = NULL + self._kdtree_float_int32_t = NULL + self._kdtree_double_int32_t = NULL + self._kdtree_float_int64_t = NULL + self._kdtree_double_int64_t = NULL - def __init__(KDTree self, np.ndarray data_pts not None, int leafsize=16): + def __init__(KDTree self, np.ndarray data_pts not None, int leafsize=16, int index_bits=32): # Check arguments if leafsize < 1: @@ -98,6 +145,8 @@ cdef class KDTree: raise ValueError('data_pts array should have exactly 2 dimensions') if data_pts.size == 0: raise ValueError('data_pts should be non-empty') + if index_bits not in [32, 64]: + raise ValueError('index_bits must be either 32 or 64') # Get data content cdef np.ndarray[float, ndim=1] data_array_float @@ -116,7 +165,10 @@ cdef class KDTree: self.data = self.data_pts # Get tree info - self.n = data_pts.shape[0] + self.index_bits = index_bits + self.n = data_pts.shape[0] + if self.index_bits == 32 and self.n > UINT32_MAX: + raise ValueError('Set index_bits=64 for more than 2^32 data points') self.leafsize = leafsize if data_pts.ndim == 1: self.ndim = 1 @@ -127,13 +179,23 @@ cdef class KDTree: # Release GIL and construct tree if data_pts.dtype == np.float32: - with nogil: - self._kdtree_float = construct_tree_float(self._data_pts_data_float, self.ndim, - self.n, self.leafsize) + if self.index_bits == 32: + with nogil: + self._kdtree_float_int32_t = construct_tree_float_int32_t(self._data_pts_data_float, self.ndim, + self.n, self.leafsize) + else: + with nogil: + self._kdtree_float_int64_t = construct_tree_float_int64_t(self._data_pts_data_float, self.ndim, + self.n, self.leafsize) else: - with nogil: - self._kdtree_double = construct_tree_double(self._data_pts_data_double, self.ndim, - self.n, self.leafsize) + if self.index_bits == 32: + with nogil: + self._kdtree_double_int32_t = construct_tree_double_int32_t(self._data_pts_data_double, self.ndim, + self.n, self.leafsize) + else: + with nogil: + self._kdtree_double_int64_t = construct_tree_double_int64_t(self._data_pts_data_double, self.ndim, + self.n, self.leafsize) def query(KDTree self, np.ndarray query_pts not None, k=1, eps=0, @@ -171,6 +233,8 @@ cdef class KDTree: elif distance_upper_bound is not None: if distance_upper_bound < 0: raise ValueError('distance_upper_bound must be non negative') + elif self.index_bits == 32 and self.n * k > UINT32_MAX: + raise ValueError('Set index_bits=64 for num points * num neighbours greater than 2^32') # Check dimensions if query_pts.ndim == 1: @@ -185,18 +249,27 @@ cdef class KDTree: raise TypeError('Type mismatch. query points must be of type float32 when data points are of type float32') # Get query info - cdef uint32_t num_qpoints = query_pts.shape[0] - cdef uint32_t num_n = k - cdef np.ndarray[uint32_t, ndim=1] closest_idxs = np.empty(num_qpoints * k, dtype=np.uint32) + cdef uint64_t num_qpoints = query_pts.shape[0] + cdef uint64_t num_n = k + cdef np.ndarray[uint32_t, ndim=1] closest_idxs_int32_t + cdef np.ndarray[uint64_t, ndim=1] closest_idxs_int64_t cdef np.ndarray[float, ndim=1] closest_dists_float cdef np.ndarray[double, ndim=1] closest_dists_double - # Set up return arrays - cdef uint32_t *closest_idxs_data = closest_idxs.data + cdef uint32_t *closest_idxs_data_int32_t + cdef uint64_t *closest_idxs_data_int64_t cdef float *closest_dists_data_float cdef double *closest_dists_data_double - + if self.index_bits == 32: + closest_idxs_int32_t = np.empty(num_qpoints * k, dtype=np.uint32) + closest_idxs = closest_idxs_int32_t + closest_idxs_data_int32_t = closest_idxs_int32_t.data + else: + closest_idxs_int64_t = np.empty(num_qpoints * k, dtype=np.uint64) + closest_idxs = closest_idxs_int64_t + closest_idxs_data_int64_t = closest_idxs_int64_t.data + # Get query points data cdef np.ndarray[float, ndim=1] query_array_float cdef np.ndarray[double, ndim=1] query_array_double @@ -247,17 +320,28 @@ cdef class KDTree: # Release GIL and query tree if self.data_pts.dtype == np.float32: - with nogil: - search_tree_float(self._kdtree_float, self._data_pts_data_float, - query_array_data_float, num_qpoints, num_n, dub_float, epsilon_float, - query_mask_data, closest_idxs_data, closest_dists_data_float) - + if self.index_bits == 32: + with nogil: + search_tree_float_int32_t(self._kdtree_float_int32_t, self._data_pts_data_float, + query_array_data_float, num_qpoints, num_n, dub_float, epsilon_float, + query_mask_data, closest_idxs_data_int32_t, closest_dists_data_float) + else: + with nogil: + search_tree_float_int64_t(self._kdtree_float_int64_t, self._data_pts_data_float, + query_array_data_float, num_qpoints, num_n, dub_float, epsilon_float, + query_mask_data, closest_idxs_data_int64_t, closest_dists_data_float) else: - with nogil: - search_tree_double(self._kdtree_double, self._data_pts_data_double, - query_array_data_double, num_qpoints, num_n, dub_double, epsilon_double, - query_mask_data, closest_idxs_data, closest_dists_data_double) - + if self.index_bits == 32: + with nogil: + search_tree_double_int32_t(self._kdtree_double_int32_t, self._data_pts_data_double, + query_array_data_double, num_qpoints, num_n, dub_double, epsilon_double, + query_mask_data, closest_idxs_data_int32_t, closest_dists_data_double) + else: + with nogil: + search_tree_double_int64_t(self._kdtree_double_int64_t, self._data_pts_data_double, + query_array_data_double, num_qpoints, num_n, dub_double, epsilon_double, + query_mask_data, closest_idxs_data_int64_t, closest_dists_data_double) + # Shape result if k > 1: closest_dists_res = closest_dists.reshape(num_qpoints, k) @@ -281,7 +365,11 @@ cdef class KDTree: return closest_dists_res, closest_idxs_res def __dealloc__(KDTree self): - if self._kdtree_float != NULL: - delete_tree_float(self._kdtree_float) - elif self._kdtree_double != NULL: - delete_tree_double(self._kdtree_double) + if self._kdtree_float_int32_t != NULL: + delete_tree_float_int32_t(self._kdtree_float_int32_t) + elif self._kdtree_double_int32_t != NULL: + delete_tree_double_int32_t(self._kdtree_double_int32_t) + if self._kdtree_float_int64_t != NULL: + delete_tree_float_int64_t(self._kdtree_float_int64_t) + elif self._kdtree_double_int64_t != NULL: + delete_tree_double_int64_t(self._kdtree_double_int64_t) diff --git a/pykdtree/test_tree.py b/pykdtree/test_tree.py index 331ee4c..2bac4f6 100644 --- a/pykdtree/test_tree.py +++ b/pykdtree/test_tree.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from pykdtree.kdtree import KDTree @@ -103,17 +104,20 @@ [ 750056.375, -46624.227, 6326519. ], [ 749718.875, -43993.633, 6326578. ]]) -def test1d(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test1d(index_bits): data_pts = np.arange(1000)[..., None] - kdtree = KDTree(data_pts, leafsize=15) + kdtree = KDTree(data_pts, leafsize=15, index_bits=index_bits) query_pts = np.arange(400, 300, -10)[..., None] dist, idx = kdtree.query(query_pts) assert idx[0] == 400 assert dist[0] == 0 assert idx[1] == 390 + assert idx.dtype.itemsize * 8 == index_bits -def test3d(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d(index_bits): #7, 93, 45 @@ -122,7 +126,7 @@ def test3d(): [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real) + kdtree = KDTree(data_pts_real, index_bits=index_bits) dist, idx = kdtree.query(query_pts, sqr_dists=True) epsilon = 1e-5 @@ -132,8 +136,10 @@ def test3d(): assert dist[0] == 0 assert abs(dist[1] - 3.) < epsilon * dist[1] assert abs(dist[2] - 20001.) < epsilon * dist[2] + assert idx.dtype.itemsize * 8 == index_bits -def test3d_float32(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_float32(index_bits): #7, 93, 45 @@ -142,7 +148,7 @@ def test3d_float32(): [769957.188, -202418.125, 6321069.5]], dtype=np.float32) - kdtree = KDTree(data_pts_real.astype(np.float32)) + kdtree = KDTree(data_pts_real.astype(np.float32), index_bits=index_bits) dist, idx = kdtree.query(query_pts, sqr_dists=True) epsilon = 1e-5 assert idx[0] == 7 @@ -152,8 +158,10 @@ def test3d_float32(): assert abs(dist[1] - 3.) < epsilon * dist[1] assert abs(dist[2] - 20001.) < epsilon * dist[2] assert kdtree.data_pts.dtype == np.float32 + assert idx.dtype.itemsize * 8 == index_bits -def test3d_float32_mismatch(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_float32_mismatch(index_bits): #7, 93, 45 @@ -161,10 +169,11 @@ def test3d_float32_mismatch(): [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]], dtype=np.float32) - kdtree = KDTree(data_pts_real) + kdtree = KDTree(data_pts_real, index_bits=index_bits) dist, idx = kdtree.query(query_pts, sqr_dists=True) -def test3d_float32_mismatch2(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_float32_mismatch2(index_bits): #7, 93, 45 @@ -172,20 +181,20 @@ def test3d_float32_mismatch2(): [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real.astype(np.float32)) + kdtree = KDTree(data_pts_real.astype(np.float32), index_bits=index_bits) try: dist, idx = kdtree.query(query_pts, sqr_dists=True) assert False except TypeError: assert True - -def test3d_8n(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_8n(index_bits): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real) + kdtree = KDTree(data_pts_real, index_bits=index_bits) dist, idx = kdtree.query(query_pts, k=8) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -202,12 +211,13 @@ def test3d_8n(): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -def test3d_8n_ub(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_8n_ub(index_bits): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real) + kdtree = KDTree(data_pts_real, index_bits=index_bits) dist, idx = kdtree.query(query_pts, k=8, distance_upper_bound=10e3, sqr_dists=False) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -224,12 +234,13 @@ def test3d_8n_ub(): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -def test3d_8n_ub_leaf20(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_8n_ub_leaf20(index_bits): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real, leafsize=20) + kdtree = KDTree(data_pts_real, leafsize=20, index_bits=index_bits) dist, idx = kdtree.query(query_pts, k=8, distance_upper_bound=10e3, sqr_dists=False) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -246,12 +257,13 @@ def test3d_8n_ub_leaf20(): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -def test3d_8n_ub_eps(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_8n_ub_eps(index_bits): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real) + kdtree = KDTree(data_pts_real, index_bits=index_bits) dist, idx = kdtree.query(query_pts, k=8, eps=0.1, distance_upper_bound=10e3, sqr_dists=False) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -268,7 +280,8 @@ def test3d_8n_ub_eps(): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -def test3d_large_query(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_large_query(index_bits): # Target idxs: 7, 93, 45 query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], @@ -278,7 +291,7 @@ def test3d_large_query(): n = 20000 query_pts = np.repeat(query_pts, n, axis=0) - kdtree = KDTree(data_pts_real) + kdtree = KDTree(data_pts_real, index_bits=index_bits) dist, idx = kdtree.query(query_pts, sqr_dists=True) epsilon = 1e-5 @@ -289,17 +302,19 @@ def test3d_large_query(): assert np.all(abs(dist[n:2*n] - 3.) < epsilon * dist[n:2*n]) assert np.all(abs(dist[2*n:] - 20001.) < epsilon * dist[2*n:]) -def test_scipy_comp(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test_scipy_comp(index_bits): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real) + kdtree = KDTree(data_pts_real, index_bits=index_bits) assert id(kdtree.data) == id(kdtree.data_pts) -def test1d_mask(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test1d_mask(index_bits): data_pts = np.arange(1000)[..., None] # put the input locations in random order np.random.shuffle(data_pts) @@ -307,7 +322,7 @@ def test1d_mask(): print(bad_idx) nearest_idx_1 = np.nonzero(data_pts[..., 0] == 399) nearest_idx_2 = np.nonzero(data_pts[..., 0] == 390) - kdtree = KDTree(data_pts, leafsize=15) + kdtree = KDTree(data_pts, leafsize=15, index_bits=index_bits) # shift the query points just a little bit for known neighbors # we want 399 as a result, not 401, when we query for ~400 query_pts = np.arange(399.9, 299.9, -10)[..., None] @@ -320,10 +335,11 @@ def test1d_mask(): assert np.isclose(dist[1], 0.1) -def test1d_all_masked(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test1d_all_masked(index_bits): data_pts = np.arange(1000)[..., None] np.random.shuffle(data_pts) - kdtree = KDTree(data_pts, leafsize=15) + kdtree = KDTree(data_pts, leafsize=15, index_bits=index_bits) query_pts = np.arange(400, 300, -10)[..., None] query_mask = np.ones(data_pts.shape[0]).astype(bool) dist, idx = kdtree.query(query_pts, mask=query_mask) @@ -332,7 +348,8 @@ def test1d_all_masked(): assert np.all(d >= 1001 for d in dist) -def test3d_mask(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test3d_mask(index_bits): #7, 93, 45 query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], @@ -352,30 +369,33 @@ def test3d_mask(): assert abs(dist[1] - 3.) < epsilon * dist[1] assert abs(dist[2] - 20001.) < epsilon * dist[2] -def test128d_fail(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test128d_fail(index_bits): pts = 100 dims = 128 data_pts = np.arange(pts * dims).reshape(pts, dims) try: - kdtree = KDTree(data_pts) + kdtree = KDTree(data_pts, index_bits=index_bits) except ValueError as exc: assert "Max 127 dimensions" in str(exc) else: raise Exception("Should not accept 129 dimensional data") -def test127d_ok(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test127d_ok(index_bits): pts = 2 dims = 127 data_pts = np.arange(pts * dims).reshape(pts, dims) - kdtree = KDTree(data_pts) + kdtree = KDTree(data_pts, index_bits=index_bits) dist, idx = kdtree.query(data_pts) assert np.all(dist == 0) -def test_empty_fail(): +@pytest.mark.parametrize("index_bits", [32, 64]) +def test_empty_fail(index_bits): data_pts = np.array([1, 2, 3]) try: - kdtree = KDTree(data_pts) + kdtree = KDTree(data_pts, index_bits=index_bits) except ValueError as e: assert 'exactly 2 dimensions' in str(e), str(e) data_pts = np.array([[]]) @@ -383,3 +403,12 @@ def test_empty_fail(): kdtree = KDTree(data_pts) except ValueError as e: assert 'non-empty' in str(e), str(e) + + +def test_points_k_too_large_for_32bits_fail(): + data_pts = np.zeros((2**16, 2), dtype=np.float32) + kdtree = KDTree(data_pts, index_bits=32) + try: + kdtree.query(data_pts, 2**16) + except ValueError as e: + assert 'greater than 2^32' in str(e), str(e) From 40d312fd236a2be9b1ba09d4b2c8b78bbf86e7bc Mon Sep 17 00:00:00 2001 From: Liam Keegan Date: Mon, 20 Jan 2025 09:22:15 +0100 Subject: [PATCH 2/2] Automatically determine required index type - add `_use_int32_t` boolean to KDTree - `true` if number of points < `UINT32_MAX` - int32 used in calculations - results returned as `np.uint32` - otherwise `false` - int64 used - results returned as `np.uint64` - make `i`, `j` int64 in `search_tree_*_int32_t` functions - avoids expressions like `i*k` overflowing if `i` is close to `UINT32_MAX` - remove `index_bits` argument - add tests - skip test with n>2^32 points by default due to ram/runtime requirements --- pykdtree/_kdtree_core.c.mako | 11 +-- pykdtree/kdtree.pyx | 24 +++---- pykdtree/test_tree.py | 125 ++++++++++++++++++----------------- 3 files changed, 78 insertions(+), 82 deletions(-) diff --git a/pykdtree/_kdtree_core.c.mako b/pykdtree/_kdtree_core.c.mako index dbe31f8..fc921b8 100644 --- a/pykdtree/_kdtree_core.c.mako +++ b/pykdtree/_kdtree_core.c.mako @@ -716,13 +716,14 @@ void search_tree_${DTYPE}_${ITYPE}(Tree_${DTYPE}_${ITYPE} *tree, ${DTYPE} *pa, $ int8_t no_dims = tree->no_dims; ${DTYPE} *bbox = tree->bbox; u${ITYPE} *pidx = tree->pidx; - u${ITYPE} j = 0; #if defined(_MSC_VER) && defined(_OPENMP) - ${ITYPE} i = 0; - ${ITYPE} local_num_points = (${ITYPE}) num_points; + int64_t i = 0; + int64_t j = 0; + int64_t local_num_points = (int64_t) num_points; #else - u${ITYPE} i; - u${ITYPE} local_num_points = num_points; + uint64_t i; + uint64_t j; + uint64_t local_num_points = (uint64_t) num_points; #endif Node_${DTYPE}_${ITYPE} *root = (Node_${DTYPE}_${ITYPE} *)tree->root; diff --git a/pykdtree/kdtree.pyx b/pykdtree/kdtree.pyx index 5e5d066..476155e 100644 --- a/pykdtree/kdtree.pyx +++ b/pykdtree/kdtree.pyx @@ -113,14 +113,13 @@ cdef class KDTree: Data points with shape (n , dims) leafsize : int, optional Maximum number of data points in tree leaf - index_bits : int, optional - Number of bits (32 or 64) to use for indexing. Default is 32, use 64 bits if n * k > 2^32. """ cdef tree_float_int32_t *_kdtree_float_int32_t cdef tree_double_int32_t *_kdtree_double_int32_t cdef tree_float_int64_t *_kdtree_float_int64_t cdef tree_double_int64_t *_kdtree_double_int64_t + cdef readonly bint _use_int32_t cdef readonly np.ndarray data_pts cdef readonly np.ndarray data cdef float *_data_pts_data_float @@ -128,7 +127,6 @@ cdef class KDTree: cdef readonly uint64_t n cdef readonly int8_t ndim cdef readonly uint32_t leafsize - cdef readonly uint32_t index_bits def __cinit__(KDTree self): self._kdtree_float_int32_t = NULL @@ -136,7 +134,7 @@ cdef class KDTree: self._kdtree_float_int64_t = NULL self._kdtree_double_int64_t = NULL - def __init__(KDTree self, np.ndarray data_pts not None, int leafsize=16, int index_bits=32): + def __init__(KDTree self, np.ndarray data_pts not None, int leafsize=16): # Check arguments if leafsize < 1: @@ -145,8 +143,6 @@ cdef class KDTree: raise ValueError('data_pts array should have exactly 2 dimensions') if data_pts.size == 0: raise ValueError('data_pts should be non-empty') - if index_bits not in [32, 64]: - raise ValueError('index_bits must be either 32 or 64') # Get data content cdef np.ndarray[float, ndim=1] data_array_float @@ -165,10 +161,8 @@ cdef class KDTree: self.data = self.data_pts # Get tree info - self.index_bits = index_bits self.n = data_pts.shape[0] - if self.index_bits == 32 and self.n > UINT32_MAX: - raise ValueError('Set index_bits=64 for more than 2^32 data points') + self._use_int32_t = self.n < UINT32_MAX self.leafsize = leafsize if data_pts.ndim == 1: self.ndim = 1 @@ -179,7 +173,7 @@ cdef class KDTree: # Release GIL and construct tree if data_pts.dtype == np.float32: - if self.index_bits == 32: + if self._use_int32_t: with nogil: self._kdtree_float_int32_t = construct_tree_float_int32_t(self._data_pts_data_float, self.ndim, self.n, self.leafsize) @@ -188,7 +182,7 @@ cdef class KDTree: self._kdtree_float_int64_t = construct_tree_float_int64_t(self._data_pts_data_float, self.ndim, self.n, self.leafsize) else: - if self.index_bits == 32: + if self._use_int32_t: with nogil: self._kdtree_double_int32_t = construct_tree_double_int32_t(self._data_pts_data_double, self.ndim, self.n, self.leafsize) @@ -233,8 +227,6 @@ cdef class KDTree: elif distance_upper_bound is not None: if distance_upper_bound < 0: raise ValueError('distance_upper_bound must be non negative') - elif self.index_bits == 32 and self.n * k > UINT32_MAX: - raise ValueError('Set index_bits=64 for num points * num neighbours greater than 2^32') # Check dimensions if query_pts.ndim == 1: @@ -261,7 +253,7 @@ cdef class KDTree: cdef uint64_t *closest_idxs_data_int64_t cdef float *closest_dists_data_float cdef double *closest_dists_data_double - if self.index_bits == 32: + if self._use_int32_t: closest_idxs_int32_t = np.empty(num_qpoints * k, dtype=np.uint32) closest_idxs = closest_idxs_int32_t closest_idxs_data_int32_t = closest_idxs_int32_t.data @@ -320,7 +312,7 @@ cdef class KDTree: # Release GIL and query tree if self.data_pts.dtype == np.float32: - if self.index_bits == 32: + if self._use_int32_t: with nogil: search_tree_float_int32_t(self._kdtree_float_int32_t, self._data_pts_data_float, query_array_data_float, num_qpoints, num_n, dub_float, epsilon_float, @@ -331,7 +323,7 @@ cdef class KDTree: query_array_data_float, num_qpoints, num_n, dub_float, epsilon_float, query_mask_data, closest_idxs_data_int64_t, closest_dists_data_float) else: - if self.index_bits == 32: + if self._use_int32_t: with nogil: search_tree_double_int32_t(self._kdtree_double_int32_t, self._data_pts_data_double, query_array_data_double, num_qpoints, num_n, dub_double, epsilon_double, diff --git a/pykdtree/test_tree.py b/pykdtree/test_tree.py index 2bac4f6..0ef83ce 100644 --- a/pykdtree/test_tree.py +++ b/pykdtree/test_tree.py @@ -104,20 +104,17 @@ [ 750056.375, -46624.227, 6326519. ], [ 749718.875, -43993.633, 6326578. ]]) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test1d(index_bits): +def test1d(): data_pts = np.arange(1000)[..., None] - kdtree = KDTree(data_pts, leafsize=15, index_bits=index_bits) + kdtree = KDTree(data_pts, leafsize=15) query_pts = np.arange(400, 300, -10)[..., None] dist, idx = kdtree.query(query_pts) assert idx[0] == 400 assert dist[0] == 0 assert idx[1] == 390 - assert idx.dtype.itemsize * 8 == index_bits -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d(index_bits): +def test3d(): #7, 93, 45 @@ -126,7 +123,7 @@ def test3d(index_bits): [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real, index_bits=index_bits) + kdtree = KDTree(data_pts_real) dist, idx = kdtree.query(query_pts, sqr_dists=True) epsilon = 1e-5 @@ -136,10 +133,8 @@ def test3d(index_bits): assert dist[0] == 0 assert abs(dist[1] - 3.) < epsilon * dist[1] assert abs(dist[2] - 20001.) < epsilon * dist[2] - assert idx.dtype.itemsize * 8 == index_bits -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_float32(index_bits): +def test3d_float32(): #7, 93, 45 @@ -148,7 +143,7 @@ def test3d_float32(index_bits): [769957.188, -202418.125, 6321069.5]], dtype=np.float32) - kdtree = KDTree(data_pts_real.astype(np.float32), index_bits=index_bits) + kdtree = KDTree(data_pts_real.astype(np.float32)) dist, idx = kdtree.query(query_pts, sqr_dists=True) epsilon = 1e-5 assert idx[0] == 7 @@ -158,10 +153,8 @@ def test3d_float32(index_bits): assert abs(dist[1] - 3.) < epsilon * dist[1] assert abs(dist[2] - 20001.) < epsilon * dist[2] assert kdtree.data_pts.dtype == np.float32 - assert idx.dtype.itemsize * 8 == index_bits -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_float32_mismatch(index_bits): +def test3d_float32_mismatch(): #7, 93, 45 @@ -169,11 +162,10 @@ def test3d_float32_mismatch(index_bits): [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]], dtype=np.float32) - kdtree = KDTree(data_pts_real, index_bits=index_bits) + kdtree = KDTree(data_pts_real) dist, idx = kdtree.query(query_pts, sqr_dists=True) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_float32_mismatch2(index_bits): +def test3d_float32_mismatch2(): #7, 93, 45 @@ -181,20 +173,20 @@ def test3d_float32_mismatch2(index_bits): [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real.astype(np.float32), index_bits=index_bits) + kdtree = KDTree(data_pts_real.astype(np.float32)) try: dist, idx = kdtree.query(query_pts, sqr_dists=True) assert False except TypeError: assert True -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_8n(index_bits): + +def test3d_8n(): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real, index_bits=index_bits) + kdtree = KDTree(data_pts_real) dist, idx = kdtree.query(query_pts, k=8) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -211,13 +203,12 @@ def test3d_8n(index_bits): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_8n_ub(index_bits): +def test3d_8n_ub(): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real, index_bits=index_bits) + kdtree = KDTree(data_pts_real) dist, idx = kdtree.query(query_pts, k=8, distance_upper_bound=10e3, sqr_dists=False) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -234,13 +225,12 @@ def test3d_8n_ub(index_bits): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_8n_ub_leaf20(index_bits): +def test3d_8n_ub_leaf20(): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real, leafsize=20, index_bits=index_bits) + kdtree = KDTree(data_pts_real, leafsize=20) dist, idx = kdtree.query(query_pts, k=8, distance_upper_bound=10e3, sqr_dists=False) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -257,13 +247,12 @@ def test3d_8n_ub_leaf20(index_bits): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_8n_ub_eps(index_bits): +def test3d_8n_ub_eps(): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real, index_bits=index_bits) + kdtree = KDTree(data_pts_real) dist, idx = kdtree.query(query_pts, k=8, eps=0.1, distance_upper_bound=10e3, sqr_dists=False) exp_dist = np.array([[ 0.00000000e+00, 4.05250235e+03, 4.07389794e+03, 8.08201128e+03, @@ -280,8 +269,7 @@ def test3d_8n_ub_eps(index_bits): assert np.array_equal(idx, exp_idx) assert np.allclose(dist, exp_dist) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_large_query(index_bits): +def test3d_large_query(): # Target idxs: 7, 93, 45 query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], @@ -291,7 +279,7 @@ def test3d_large_query(index_bits): n = 20000 query_pts = np.repeat(query_pts, n, axis=0) - kdtree = KDTree(data_pts_real, index_bits=index_bits) + kdtree = KDTree(data_pts_real) dist, idx = kdtree.query(query_pts, sqr_dists=True) epsilon = 1e-5 @@ -302,19 +290,17 @@ def test3d_large_query(index_bits): assert np.all(abs(dist[n:2*n] - 3.) < epsilon * dist[n:2*n]) assert np.all(abs(dist[2*n:] - 20001.) < epsilon * dist[2*n:]) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test_scipy_comp(index_bits): +def test_scipy_comp(): query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], [769957.188, -202418.125, 6321069.5]]) - kdtree = KDTree(data_pts_real, index_bits=index_bits) + kdtree = KDTree(data_pts_real) assert id(kdtree.data) == id(kdtree.data_pts) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test1d_mask(index_bits): +def test1d_mask(): data_pts = np.arange(1000)[..., None] # put the input locations in random order np.random.shuffle(data_pts) @@ -322,7 +308,7 @@ def test1d_mask(index_bits): print(bad_idx) nearest_idx_1 = np.nonzero(data_pts[..., 0] == 399) nearest_idx_2 = np.nonzero(data_pts[..., 0] == 390) - kdtree = KDTree(data_pts, leafsize=15, index_bits=index_bits) + kdtree = KDTree(data_pts, leafsize=15) # shift the query points just a little bit for known neighbors # we want 399 as a result, not 401, when we query for ~400 query_pts = np.arange(399.9, 299.9, -10)[..., None] @@ -335,11 +321,10 @@ def test1d_mask(index_bits): assert np.isclose(dist[1], 0.1) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test1d_all_masked(index_bits): +def test1d_all_masked(): data_pts = np.arange(1000)[..., None] np.random.shuffle(data_pts) - kdtree = KDTree(data_pts, leafsize=15, index_bits=index_bits) + kdtree = KDTree(data_pts, leafsize=15) query_pts = np.arange(400, 300, -10)[..., None] query_mask = np.ones(data_pts.shape[0]).astype(bool) dist, idx = kdtree.query(query_pts, mask=query_mask) @@ -348,8 +333,7 @@ def test1d_all_masked(index_bits): assert np.all(d >= 1001 for d in dist) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test3d_mask(index_bits): +def test3d_mask(): #7, 93, 45 query_pts = np.array([[ 787014.438, -340616.906, 6313018.], [751763.125, -59925.969, 6326205.5], @@ -369,33 +353,30 @@ def test3d_mask(index_bits): assert abs(dist[1] - 3.) < epsilon * dist[1] assert abs(dist[2] - 20001.) < epsilon * dist[2] -@pytest.mark.parametrize("index_bits", [32, 64]) -def test128d_fail(index_bits): +def test128d_fail(): pts = 100 dims = 128 data_pts = np.arange(pts * dims).reshape(pts, dims) try: - kdtree = KDTree(data_pts, index_bits=index_bits) + kdtree = KDTree(data_pts) except ValueError as exc: assert "Max 127 dimensions" in str(exc) else: raise Exception("Should not accept 129 dimensional data") -@pytest.mark.parametrize("index_bits", [32, 64]) -def test127d_ok(index_bits): +def test127d_ok(): pts = 2 dims = 127 data_pts = np.arange(pts * dims).reshape(pts, dims) - kdtree = KDTree(data_pts, index_bits=index_bits) + kdtree = KDTree(data_pts) dist, idx = kdtree.query(data_pts) assert np.all(dist == 0) -@pytest.mark.parametrize("index_bits", [32, 64]) -def test_empty_fail(index_bits): +def test_empty_fail(): data_pts = np.array([1, 2, 3]) try: - kdtree = KDTree(data_pts, index_bits=index_bits) + kdtree = KDTree(data_pts) except ValueError as e: assert 'exactly 2 dimensions' in str(e), str(e) data_pts = np.array([[]]) @@ -404,11 +385,33 @@ def test_empty_fail(index_bits): except ValueError as e: assert 'non-empty' in str(e), str(e) - -def test_points_k_too_large_for_32bits_fail(): - data_pts = np.zeros((2**16, 2), dtype=np.float32) - kdtree = KDTree(data_pts, index_bits=32) - try: - kdtree.query(data_pts, 2**16) - except ValueError as e: - assert 'greater than 2^32' in str(e), str(e) +def test_tree_n_lt_maxint32_nk_gt_maxint32(): + # n < UINT32_MAX but n * k > UINT32_MAX -> still uses 32-bit index + data_pts = np.random.random((2**20, 2)).astype(np.float32) + query_pts = np.random.random((3, 2)).astype(np.float32) + data_pts[0] = query_pts[0] + data_pts[1533] = query_pts[1] + data_pts[1048575] = query_pts[2] + kdtree = KDTree(data_pts) + dist, idx = kdtree.query(query_pts, k=2**14) + assert idx.shape == (3, 2**14) + assert idx.dtype == np.uint32 + assert idx[0][0] == 0 + assert idx[1][0] == 1533 + assert idx[2][0] == 1048575 + +@pytest.mark.skip(reason="Requires ~100G RAM, takes ~30mins to run") +def test_tree_n_gt_maxint32(): + # n > UINT32_MAX -> requires 64-bit index + data_pts = np.random.random((2**32 + 8, 2)).astype(np.float32) + query_pts = np.random.random((3, 2)).astype(np.float32) + data_pts[0] = query_pts[0] + data_pts[874516] = query_pts[1] + data_pts[4294967300] = query_pts[2] + kdtree = KDTree(data_pts) + dist, idx = kdtree.query(query_pts, k=12) + assert idx.shape == (3, 12) + assert idx.dtype == np.uint64 + assert idx[0][0] == 0 + assert idx[1][0] == 874516 + assert idx[2][0] == 4294967300