add support for linear probing
smroid committed Feb 22, 2025
1 parent 89f0c04 commit ea4d6db
Showing 1 changed file with 93 additions and 39 deletions.
132 changes: 93 additions & 39 deletions tetra3/tetra3.py
@@ -146,50 +146,84 @@
_supported_databases = ('bsc5', 'hip_main', 'tyc_main')
_lib_root = Path(__file__).parent

-def _insert_at_index(pattern, hash_index, table):
-    """Inserts to table with quadratic probing. Returns table index where pattern was inserted."""
+def _is_prime(n):
+    if n < 2:
+        return False
+    if n == 2:
+        return True
+    if n % 2 == 0:
+        return False
+    # Only check odd numbers up to sqrt(n)
+    for i in range(3, int(n ** 0.5) + 1, 2):
+        if n % i == 0:
+            return False
+    return True
+
+def _next_prime(n):
+    if n < 2:
+        return 2
+    n = n + 1 + (n % 2)  # Next odd number after n
+    while not _is_prime(n):
+        n += 2  # Skip even numbers
+    return n
+
+def _insert_at_index(pattern, hash_index, table, linear_probe):
+    """Inserts to table with quadratic or linear probing. Returns table index where
+    pattern was inserted."""
    max_ind = np.uint64(table.shape[0])
    hash_index = np.uint64(hash_index)
    for c in itertools.count():
        c = np.uint64(c)
-        i = (hash_index + c*c) % max_ind
+        if linear_probe:
+            i = (hash_index + c) % max_ind
+        else:
+            i = (hash_index + c*c) % max_ind
        if all(table[i, :] == 0):
            table[i, :] = pattern
            return i

-def _get_table_indices_from_hash(hash_index, table):
-    """Gets from table with quadratic probing, returns list of all possibly matching indices."""
+def _get_table_indices_from_hash(hash_index, table, linear_probe):
+    """Gets from table with quadratic or linear probing, returns list of all
+    possibly matching indices."""
    max_ind = np.uint64(table.shape[0])
    hash_index = np.uint64(hash_index)
    found = []
    for c in itertools.count():
        c = np.uint64(c)
-        i = (hash_index + c*c) % max_ind
+        if linear_probe:
+            i = (hash_index + c) % max_ind
+        else:
+            i = (hash_index + c*c) % max_ind
        if all(table[i, :] == 0):
            return np.array(found)
        else:
            found.append(i)

-def _pattern_hash_to_index(pattern_hash, bin_factor, max_index):
+def _pattern_hash_to_index(pattern_hash, bin_factor, max_index, linear_probe):
    """Get hash index for a given pattern_hash (tuple of ordered binned edge ratios).
    Can be length p list or n by p array."""
    pattern_hash = np.uint64(pattern_hash)
    bin_factor = np.uint64(bin_factor)
    max_index = np.uint64(max_index)
    # Combine pattern_hash components to a single large number.
-    # If p is the length of the pattern_hash (default 5) and B is the number of bins (default 50,
-    # calculated from max error), this will first give each pattern_hash a unique index from
-    # 0 to B^p-1, then multiply by large number and modulo to max index to randomise.
+    # If p is the length of the pattern_hash (default 5) and B is the number of bins
+    # (default 50, calculated from max error), this will first give each pattern_hash
+    # a unique index from 0 to B^p-1.
    if pattern_hash.ndim == 1:
-        combined = np.sum(pattern_hash*bin_factor**np.arange(len(pattern_hash), dtype=np.uint64),
+        combined = np.sum(pattern_hash*bin_factor**np.arange(len(pattern_hash),
+                                                             dtype=np.uint64),
                          dtype=np.uint64)
    else:
        combined = np.sum(pattern_hash*bin_factor**np.arange(pattern_hash.shape[1],
                                                             dtype=np.uint64)[None, :],
                          axis=1, dtype=np.uint64)
-    with np.errstate(over='ignore'):
-        combined = (combined*_MAGIC_RAND) % max_index
-    return combined
+    if linear_probe:
+        return combined % max_index
+    else:
+        # For legacy compatibility.
+        with np.errstate(over='ignore'):
+            return (combined*_MAGIC_RAND) % max_index

def _compute_vectors(centroids, size, fov):
"""Get unit vectors from star centroids (pinhole camera)."""
@@ -486,7 +520,7 @@ def database_properties(self):
        Keys:
        - 'pattern_mode': Method used to identify star patterns. Is always 'edge_ratio'.
        - 'hash_table_type': What algorithm is used for the pattern hash table. The only
-          value (currently) is 'quadratic_probe'.
+          values (currently) are 'quadratic_probe' and 'linear_probe'.
        - 'pattern_size': Number of stars in each pattern.
        - 'pattern_bins': Number of bins per dimension in pattern catalog.
        - 'pattern_max_error': Maximum difference allowed in pattern for a match.
@@ -850,7 +884,7 @@ def generate_database(self, max_fov, min_fov=None, save_as=None,
                          verification_stars_per_fov=150, star_max_magnitude=None,
                          pattern_max_error=.001,
                          multiscale_step=1.5, epoch_proper_motion='now',
-                          pattern_stars_per_fov=None, legacy_mode=False):
+                          pattern_stars_per_fov=None, linear_probe=False):
        """Create a database and optionally save it to file.
        Takes a few minutes for a small (large FOV) database, can take many hours for a large
@@ -954,9 +988,10 @@ def generate_database(self, max_fov, min_fov=None, save_as=None,
        Args:
            max_fov (float): Maximum angle (in degrees) between stars in the same pattern.
-            min_fov (float, optional): Minimum FOV considered when the catalogue density is trimmed to size.
-                If None (the default), min_fov will be set to max_fov, i.e. a catalogue for a single
-                application is generated (this is most efficient size and speed wise).
+            min_fov (float, optional): Minimum FOV considered when the catalogue density is
+                trimmed to size. If None (the default), min_fov will be set to max_fov, i.e.
+                a catalogue for a single application is generated (this is most efficient size
+                and speed wise).
            save_as (str or pathlib.Path, optional): Save catalogue here when finished. Calls
                :meth:`save_database`.
            star_catalog (string, optional): Abbreviated name of star catalog, one of 'bsc5',
@@ -975,8 +1010,8 @@ def generate_database(self, max_fov, min_fov=None, save_as=None,
            pattern_max_error (float, optional): This value determines the number of bins into which
                a pattern hash's edge ratios are each quantized:
                    pattern_bins = 0.25 / pattern_max_error
-                Default .001, corresponding to pattern_bins=250. For a database with limiting magnitude
-                7, this yields a reasonable pattern hash collision rate.
+                Default .001, corresponding to pattern_bins=250. For a database with limiting
+                magnitude 7, this yields a reasonable pattern hash collision rate.
            multiscale_step (float, optional): Determines the largest ratio between subsequent FOVs
                that is allowed when generating a multiscale database. Defaults to 1.5. If the ratio
                max_fov/min_fov is less than sqrt(multiscale_step) a single scale database is built.
@@ -986,15 +1021,17 @@ def generate_database(self, max_fov, min_fov=None, save_as=None,
                without proper motions to be used in the database.
            pattern_stars_per_fov (int, optional): Deprecated. If given, is used instead of
                `lattice_field_oversampling`, which has similar values.
-            legacy_mode (bool, optional): If True, uses 'quadratic_probe' for 'hash_table_type',
-                for compatibility with earlier versions of Tetra3. For new usages leave this as
-                False, to enable improved 'hash_table_type'.
+            linear_probe (bool, optional): If False (default), uses quadratic probing in the
+                hash table. This is appropriate for deployments where you expect the pattern
+                database to fit entirely in RAM. Use linear_probe=True when you expect the
+                pattern database to be too large to fit in RAM.
        """
        self._logger.debug('Got generate pattern catalogue with input: '
-                           + str((max_fov, min_fov, save_as, star_catalog, lattice_field_oversampling,
+                           + str((max_fov, min_fov, save_as, star_catalog,
+                                  lattice_field_oversampling,
                                  patterns_per_lattice_field, verification_stars_per_fov,
                                  star_max_magnitude, pattern_max_error,
-                                  multiscale_step, epoch_proper_motion, legacy_mode)))
+                                  multiscale_step, epoch_proper_motion, linear_probe)))
        if pattern_stars_per_fov is not None and pattern_stars_per_fov != lattice_field_oversampling:
            self._logger.warning(
                'pattern_stars_per_fov value %s is overriding lattice_field_oversampling value %s' %
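The docstring's RAM guidance can be made concrete. Below is a hedged sketch of a sizing heuristic; `choose_linear_probe` is hypothetical (not part of tetra3), and it leans on this commit's ~3x table factor for linear probing and the 4-star pattern size:

```python
import numpy as np

# Hypothetical helper, not in tetra3: pick the probing mode from a rough
# estimate of the pattern catalog's in-memory footprint.
def choose_linear_probe(num_patterns, pattern_size=4, index_dtype=np.uint32,
                        ram_budget_bytes=4 * 2**30):
    rows = 3 * num_patterns  # linear probing sizes the table at ~3x the patterns
    catalog_bytes = rows * pattern_size * np.dtype(index_dtype).itemsize
    return catalog_bytes > ram_budget_bytes
```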
@@ -1022,6 +1059,7 @@ def generate_database(self, max_fov, min_fov=None, save_as=None,

        patterns_per_lattice_field = int(patterns_per_lattice_field)
        verification_stars_per_fov = int(verification_stars_per_fov)
+        linear_probe = bool(linear_probe)
        if star_max_magnitude is not None:
            star_max_magnitude = float(star_max_magnitude)
        PATTERN_SIZE = 4
@@ -1249,7 +1287,10 @@ def logk(x, k):
        # Create all pattern hashes by calculating, sorting, and binning edge ratios; then compute
        # a table index hash from the pattern hash, and store the table index -> pattern mapping.
        self._logger.info('Start building catalogue.')
-        catalog_length = int(3 * len(pattern_list))
+        if linear_probe:
+            catalog_length = int(_next_prime(3 * len(pattern_list)))
+        else:
+            catalog_length = int(_next_prime(2 * len(pattern_list)))
        # Determine type to make sure the biggest index will fit, create pattern catalogue
        max_index = np.max(np.array(pattern_list))
        if max_index <= np.iinfo('uint8').max:
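The sizing rule above gives linear probing a table of roughly 3x the pattern count (load factor ~1/3) and quadratic probing roughly 2x (load factor ~1/2), each rounded up to a prime by `_next_prime`. A quick illustrative check, assuming the helper added in this commit is in scope:

```python
# Illustrative only; assumes the _next_prime helper from this commit.
num_patterns = 1_000_000
for linear_probe in (True, False):
    factor = 3 if linear_probe else 2
    catalog_length = int(_next_prime(factor * num_patterns))
    # Load factor stays near 1/3 (linear) or 1/2 (quadratic).
    print(linear_probe, catalog_length, round(num_patterns / catalog_length, 3))
```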
@@ -1284,7 +1325,8 @@ def logk(x, k):

            # convert edge ratio float to pattern hash by binning
            pattern_hash = [int(ratio * pattern_bins) for ratio in edge_ratios]
-            hash_index = _pattern_hash_to_index(pattern_hash, pattern_bins, catalog_length)
+            hash_index = _pattern_hash_to_index(
+                pattern_hash, pattern_bins, catalog_length, linear_probe)

            if EVALUATE_COLLISIONS:
                prev_len = len(pattern_hashes_seen)
@@ -1306,31 +1348,39 @@ def logk(x, k):
            # Use the radii to uniquely order the pattern, used for future matching.
            pattern = [pattern[i] for (_, i) in centroid_distances]

-            index = _insert_at_index(pattern, hash_index, pattern_catalog)
+            index = _insert_at_index(pattern, hash_index, pattern_catalog, linear_probe)
            # Store as milliradian to better use float16 range
            pattern_largest_edge[index] = largest_angle*1000
        total_probes = 0
+        max_probes = 0
        if EVALUATE_COLLISIONS:
            # Evaluate average hash table probe count.
            for pattern_hash in pattern_hashes_seen:
-                hash_index = _pattern_hash_to_index(pattern_hash, pattern_bins, catalog_length)
-                hash_match_inds = _get_table_indices_from_hash(hash_index, pattern_catalog)
-                total_probes += len(hash_match_inds)
+                hash_index = _pattern_hash_to_index(
+                    pattern_hash, pattern_bins, catalog_length, linear_probe)
+                hash_match_inds = _get_table_indices_from_hash(
+                    hash_index, pattern_catalog, linear_probe)
+                probes = len(hash_match_inds)
+                total_probes += probes
+                if probes > max_probes:
+                    max_probes = probes

        self._logger.info('Finished generating database.')
        self._logger.info('Size of uncompressed star table: %i Bytes.' %star_table.nbytes)
        self._logger.info('Size of uncompressed pattern catalog: %i Bytes.' %pattern_catalog.nbytes)
        if EVALUATE_COLLISIONS:
-            self._logger.info('Pattern hash collisions: %s; average table probe len: %.2f'
-                              % (pattern_hash_collisions, total_probes / len(pattern_hashes_seen)))
+            self._logger.info('Pattern hash collisions: %s; average/max table probe len: %.2f/%d'
+                              % (pattern_hash_collisions,
+                                 total_probes / len(pattern_hashes_seen),
+                                 max_probes))
        self._star_table = star_table
        self._star_kd_tree = vector_kd_tree
        self._star_catalog_IDs = star_catID
        self._pattern_catalog = pattern_catalog
        self._pattern_largest_edge = pattern_largest_edge
        self._db_props['pattern_mode'] = 'edge_ratio'
-        self._db_props['hash_table_type'] = 'quadratic_probe'
+        self._db_props['hash_table_type'] = 'linear_probe' if linear_probe else 'quadratic_probe'
        self._db_props['pattern_size'] = PATTERN_SIZE
        self._db_props['pattern_bins'] = pattern_bins
        self._db_props['pattern_max_error'] = pattern_max_error
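As context for the average/max probe lengths logged above, the classical expected-probe estimates for a successful search (Knuth's approximations; quadratic probing behaves roughly like uniform probing) predict small averages at the load factors chosen in this commit. A back-of-envelope sketch, not tetra3 code:

```python
import math

# Expected probes for a successful search at load factor `a`
# (Knuth's classical approximations).
def linear_probe_cost(a):
    return 0.5 * (1 + 1 / (1 - a))

def uniform_probe_cost(a):  # rough stand-in for quadratic probing
    return math.log(1 / (1 - a)) / a

print(round(linear_probe_cost(1 / 3), 2))   # ~1.25 at the 3x (load 1/3) table
print(round(uniform_probe_cost(1 / 2), 2))  # ~1.39 at the 2x (load 1/2) table
```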
@@ -1665,6 +1715,7 @@ def solve_from_centroids(self, star_centroids, size, fov_estimate=None, fov_max_
        match_max_error = self._db_props['pattern_max_error']
        p_max_err = match_max_error
        presorted = self._db_props['presort_patterns']
+        linear_probe = self._db_props['hash_table_type'] == 'linear_probe'

        # Indices to extract from dot product matrix (above diagonal)
        upper_tri_index = np.triu_indices(p_size, 1)
@@ -1779,12 +1830,13 @@ def dist(pattern_hash):
        for (_, pattern_hash) in pattern_hash_list:
            search_space_explored += 1
            # Calculate corresponding hash index.
-            hash_index = _pattern_hash_to_index(pattern_hash, p_bins, self.pattern_catalog.shape[0])
+            hash_index = _pattern_hash_to_index(
+                pattern_hash, p_bins, self.pattern_catalog.shape[0], linear_probe)

            (catalog_pattern_edges, all_catalog_pattern_vectors) = \
                self._get_all_patterns_for_index(
                    hash_index, upper_tri_index, image_pattern_largest_edge,
-                    fov_estimate, fov_max_error)
+                    fov_estimate, fov_max_error, linear_probe)
            if catalog_pattern_edges is None:
                continue
            catalog_lookup_count += len(catalog_pattern_edges)
@@ -2151,11 +2203,13 @@ def cancel_solve(self):
        self._cancelled = True

    def _get_all_patterns_for_index(self, hash_index, upper_tri_index,
-                                    image_pattern_largest_edge, fov_estimate, fov_max_error):
+                                    image_pattern_largest_edge, fov_estimate, fov_max_error,
+                                    linear_probe):
        """Returns (edges, vectors) for all pattern table entries for `hash_index`."""

        # Iterate over table hash indices.
-        hash_match_inds = _get_table_indices_from_hash(hash_index, self.pattern_catalog)
+        hash_match_inds = _get_table_indices_from_hash(
+            hash_index, self.pattern_catalog, linear_probe)
        if len(hash_match_inds) == 0:
            return (None, None)