Skip to content

Commit

Permalink
add install_db
Browse files Browse the repository at this point in the history
  • Loading branch information
TAKASE Yusuke committed Sep 17, 2024
1 parent 1e96046 commit 29231df
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 119 deletions.
51 changes: 30 additions & 21 deletions notebooks/diff_gain_boresight.ipynb

Large diffs are not rendered by default.

39 changes: 13 additions & 26 deletions notebooks/diff_gain_channel.ipynb

Large diffs are not rendered by default.

47 changes: 19 additions & 28 deletions notebooks/diff_pointing_boresight.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions notebooks/diff_pointing_channel.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@
],
"source": [
"# Load single detector map\n",
"base_path = \"../../../maps/crosslink_maps/crosslinks_2407/nside_128\"\n",
"data = ScanFields()\n",
"ch_id = 0\n",
"channel = data.all_channels[ch_id]\n",
"sf_ch = ScanFields.load_channel(base_path, channel)\n",
"\n",
"# It assumes that base_path has been installed to SBM by install_db()\n",
"sf_ch = ScanFields.load_channel(channel)\n",
"print(\"Channel: \", channel)"
]
},
Expand Down
50 changes: 29 additions & 21 deletions notebooks/load_crosslink.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/noise.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"source": [
"# Load single detector map\n",
"base_path = \"../maps\"\n",
"scan_field = ScanFields.load_det(base_path, \"nside128_boresight\")"
"scan_field = ScanFields.load_det(\"nside128_boresight\", base_path)"
]
},
{
Expand Down
6 changes: 4 additions & 2 deletions sbm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# -*- encoding: utf-8 -*-

from .sbm import (
from .main import (
Field,
SignalFields,
ScanFields,
plot_maps,
get_instrument_table,
DB_ROOT_PATH,
)


Expand All @@ -19,10 +20,11 @@
# version.py
"__author__",
"__version__",
# sbm.py
# main.py
"Field",
"SignalFields",
"ScanFields",
"plot_maps",
"get_instrument_table",
"DB_ROOT_PATH",
]
Binary file modified sbm/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file modified sbm/__pycache__/sbm.cpython-310.pyc
Binary file not shown.
Binary file modified sbm/__pycache__/version.cpython-310.pyc
Binary file not shown.
145 changes: 145 additions & 0 deletions sbm/install_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

from base64 import b64decode
import getpass
from pathlib import Path
from time import sleep
from github import Github
from rich import print
from rich.table import Table
import tomlkit
import os
import json
import numpy as np
import toml

CONFIG_PATH = Path.home() / ".config" / "sbm_dataset"
CONFIG_FILE_PATH = CONFIG_PATH / "sbm_dataset.toml"

repositories = []


# Convert numpy.int64 to Python int
def custom_encoder(obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.int64):
return int(obj)
if isinstance(obj, bytes):
return obj.decode('utf-8')
raise TypeError(f"Type {type(obj)} not serializable")

def retrieve_local_source():
print()
path = Path(
input('Please enter the directory where file "sim_config.json" resides: ')
).absolute()

if not (path / "sim_config.json").is_file():
print(f'[red]Error:[/red] {path} does not seem to contain a "sim_config.json" file')
create_file = input('Would you like to create a "sim_config.json" file? (y/n): ')
if create_file.lower() == 'y':
gen_jsonfile(path)
print(f'[green]"sim_config.json" has been created at {path}.[/green]')
else:
return

name = input("Now insert a descriptive name for this location: ")

repositories.append({"name": name, "location": str(path.resolve())})

print(
f"""
[green]Repository "{name}" has been added successfully.[/green]
"""
)

def run_main_loop() -> bool:
prompt = """Choose a source for the database:
1. [cyan]Local source[/cyan]
A directory on your hard disk.
s. [cyan]Save and quit[/cyan]
q. [cyan]Discard modifications and quit[/cyan]
"""

while True:
print(prompt)
choice = input("Pick your choice (1, s or q): ").strip()

if choice == "1":
retrieve_local_source()
elif choice in ("s", "S"):
print(
"""
Saving changes and quitting...
"""
)
return True

elif choice in ("q", "Q"):
print(
"""
Discarding any change and quitting...
"""
)
return False

sleep(2)


def write_toml_configuration():
file_path = CONFIG_FILE_PATH

# Create the directory containing the file, if it does not exist.
file_path.parent.mkdir(parents=True, exist_ok=True)

with file_path.open("wt") as outf:
outf.write(tomlkit.dumps({"repositories": repositories}))

print(
f"""
The configuration has been saved into file
"{str(file_path)}"
"""
)

def extract_location_from_toml(file_path):
with open(file_path, 'r') as file:
data = toml.load(file)
loc = data['repositories'][0]['location']
return loc

def main():
if run_main_loop():
write_toml_configuration()
if len(repositories) > 0:
print("The following repositories have been configured successfully:")

table = Table()
table.add_column("Name")
table.add_column("Location")

for row in repositories:
table.add_row(row["name"], row["location"])

print(table)

else:
print("No repositories have been configured")

else:
print("Changes have been discarded")


if __name__ == "__main__":
main()
71 changes: 54 additions & 17 deletions sbm/sbm.py → sbm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
import pandas as pd
import litebird_sim as lbs
from litebird_sim import Imo
import toml
from pathlib import Path
from .install_db import extract_location_from_toml

CONFIG_PATH = Path.home() / ".config" / "sbm_dataset"
CONFIG_FILE_PATH = CONFIG_PATH / "sbm_dataset.toml"
DB_ROOT_PATH = extract_location_from_toml(CONFIG_FILE_PATH)

class Field:
""" Class to store the field data of detectors """
def __init__(self, field: np.ndarray, spin: int):
Expand Down Expand Up @@ -113,12 +121,10 @@ def __init__(self):
self.fwhms = [70.5,58.5,51.1,41.6,47.1,36.9,43.8,33.0,41.5,30.2,26.3,23.7,37.8,33.6,30.8,28.9,28.0,28.6,24.7,22.5,20.9,17.9]

@classmethod
def load_det(cls, base_path: str, det_name: str):
def load_det(cls, det_name: str, base_path=DB_ROOT_PATH):
""" Load the scan fields data of a detector from a .h5 file
Args:
base_path (str): path to the directory containing the .h5 file
filename (str): name of the .h5 file containing the scan fields data simulated by Falcons.jl
The fileformat requires cross-link_2407-dataset's format.
The file should contain the following groups:
Expand All @@ -129,6 +135,9 @@ def load_det(cls, base_path: str, det_name: str):
- n: number of spins
- mean: mean of the hitmap and h
- std: standard deviation of the hitmap and h
base_path (str): path to the directory containing the .h5 file
Returns:
instance (ScanFields): instance of the ScanFields class containing the scan fields data of the detector
"""
Expand All @@ -141,6 +150,7 @@ def load_det(cls, base_path: str, det_name: str):
t2b = True
det_name = det_name[:-1] + "T"
filename = det_name + ".h5"

with h5py.File(os.path.join(base_path, filename), 'r') as f:
instance.ss = {key: value[()] for key, value in zip(f['ss'].keys(), f['ss'].values()) if key != "quat"}
instance.hitmap = f['hitmap'][:]
Expand All @@ -156,7 +166,7 @@ def load_det(cls, base_path: str, det_name: str):
return instance

@classmethod
def load_channel(cls, base_path: str, channel: str):
def load_channel(cls, channel: str, base_path=DB_ROOT_PATH):
"""Load the scan fields data of a channel from the directory containing the .h5 files
Args:
Expand All @@ -170,7 +180,7 @@ def load_channel(cls, base_path: str, channel: str):
dirpath = os.path.join(base_path, channel)
filenames = os.listdir(dirpath)
filenames = [os.path.splitext(filename)[0] for filename in filenames]
first_sf = cls.load_det(dirpath, filenames[0])
first_sf = cls.load_det(filenames[0], base_path=dirpath)
instance = cls()
instance.channel = channel
instance.ndet = len(filenames)
Expand All @@ -181,7 +191,7 @@ def load_channel(cls, base_path: str, channel: str):
instance.duration = first_sf.duration
instance.sampling_rate = first_sf.sampling_rate
for filename in filenames:
sf = cls.load_det(dirpath, filename)
sf = cls.load_det(filename, base_path=dirpath)
instance.hitmap += sf.hitmap
instance.h += sf.hitmap[:, np.newaxis] * sf.h
instance.h /= instance.hitmap[:, np.newaxis]
Expand All @@ -191,14 +201,14 @@ def load_channel(cls, base_path: str, channel: str):
@classmethod
def _load_channel_task(cls, args):
base_path, ch = args
return cls.load_channel(base_path, ch)
return cls.load_channel(ch, base_path)

@classmethod
def load_full_FPU(cls, base_path: str, channel_list: list, max_workers=None):
def load_full_FPU(cls, channel_list: list, base_path=DB_ROOT_PATH, max_workers=None):
""" Load the scan fields data of all the channels in the FPU from the directory containing the .h5 files
Args:
base_path (str): path to the directory containing the .h5 files
base_path (str): path to the directory containing the channel's data
channel_list (list): list of channels to load the scan fields data from
Expand Down Expand Up @@ -352,25 +362,25 @@ def diff_gain_field(gain_a, gain_b, I, P):
@classmethod
def sim_diff_gain_channel(
cls,
base_path: str,
channel: str,
mdim: int,
input_map: np.ndarray,
gain_a: np.ndarray,
gain_b: np.ndarray
gain_b: np.ndarray,
base_path=DB_ROOT_PATH,
):
dirpath = os.path.join(base_path, channel)
filenames = os.listdir(dirpath)
filenames = [os.path.splitext(filename)[0] for filename in filenames]
assert len(filenames) == len(gain_a) == len(gain_b)
total_sf = cls.load_det(dirpath, filenames[0])
total_sf = cls.load_det(filenames[0], base_path=dirpath)
total_sf.initialize(mdim)
total_sf.ndet = len(filenames)
assert input_map.shape == (3,len(total_sf.hitmap))
I = input_map[0]
P = input_map[1] + 1j*input_map[2]
for i,filename in enumerate(filenames):
sf = cls.load_det(dirpath, filename)
sf = cls.load_det(filename, base_path=dirpath)
signal_fields = ScanFields.diff_gain_field(gain_a[i], gain_b[i], I, P)
sf.couple(signal_fields, mdim)
total_sf.hitmap += sf.hitmap
Expand Down Expand Up @@ -416,21 +426,20 @@ def diff_pointing_field(
@classmethod
def sim_diff_pointing_channel(
cls,
base_path: str,
channel: str,
mdim: int,
input_map: np.ndarray,
rho_T: np.ndarray, # Pointing offset magnitude
rho_B: np.ndarray,
chi_T: np.ndarray, # Pointing offset direction
chi_B: np.ndarray,
base_path=DB_ROOT_PATH,
):

dirpath = os.path.join(base_path, channel)
filenames = os.listdir(dirpath)
filenames = [os.path.splitext(filename)[0] for filename in filenames]
assert len(filenames) == len(rho_T) == len(chi_T) == len(rho_B) == len(chi_B)
total_sf = cls.load_det(dirpath, filenames[0])
total_sf = cls.load_det(filenames[0], base_path=dirpath)
total_sf.initialize(mdim)
total_sf.ndet = len(filenames)
assert input_map.shape == (3,len(total_sf.hitmap))
Expand All @@ -447,7 +456,7 @@ def sim_diff_pointing_channel(
o_eth_P = dQ[2] - dU[1] + 1j*(dQ[1] + dU[2])

for i,filename in enumerate(filenames):
sf = cls.load_det(dirpath, filename)
sf = cls.load_det(filename, base_path=dirpath)
signal_fields = ScanFields.diff_pointing_field(rho_T[i], rho_B[i], chi_T[i], chi_B[i], P, eth_I, eth_P, o_eth_P)
sf.couple(signal_fields, mdim)
total_sf.hitmap += sf.hitmap
Expand Down Expand Up @@ -732,3 +741,31 @@ def get_instrument_table(imo:Imo, imo_version="v2"):
'telescope' : telescope_list
})
return instrument

def gen_jsonfile(base_path):
dataset = []
scan_field = None
for root, dirs, files in os.walk(base_path):
ch = root.split('/')[-1]
if files:
data = {
"channel": ch,
"detectors": files
}
dataset.append(data)
if ch == "boresight":
scan_field = ScanFields.load_det("boresight", base_path=root)

nside = int(scan_field.nside)
duration = int(scan_field.duration)
scan_strategy = scan_field.ss
considered_spin = scan_field.spins
scaninfo = {
"nside": nside,
"duration": duration,
"scan_strategy": scan_strategy,
"considered_spin": considered_spin
}
with open(os.path.join(base_path, "sim_config.json"), 'w') as f:
json.dump(scaninfo, f, indent=4, default=custom_encoder)
json.dump(dataset, f, indent=4)
2 changes: 1 addition & 1 deletion tests/test_sbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TestSBM(unittest.TestCase):
def setUp(self):
print(f"Current directory: {os.getcwd()}")
self.base_path = "maps"
self.scan_field = ScanFields.load_det(self.base_path, "nside128_boresight")
self.scan_field = ScanFields.load_det("nside128_boresight", self.base_path)
self.input_map = hp.read_map("maps/cmb_0000_nside_128_seed_33.fits", field=(0,1,2)) * 1e6
self.nside = hp.npix2nside(len(self.input_map[0]))

Expand Down

0 comments on commit 29231df

Please sign in to comment.