From 1eff4757e8efe9a1ebb18187c4d6c91b5f6fd710 Mon Sep 17 00:00:00 2001 From: JasonMH17 <134568474+JasonMH17@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:04:18 -0800 Subject: [PATCH] Copybara import of the project: COPYBARA_INTEGRATE_REVIEW=https://github.com/google/carfac/pull/19 from JasonMH17:master c81cc0fea2437969288c9bb78fb47460ced57af4 PiperOrigin-RevId: 695494768 --- python/jax/carfac.py | 45 ++++++++++++++++++++-------- python/jax/carfac_bench.py | 12 ++++---- python/jax/carfac_float64_test.py | 2 +- python/jax/carfac_test.py | 29 +++++++++++------- python/jax/carfac_util.py | 33 +++++++++++++++----- python/jax/carfac_util_test.py | 50 +++++++++++++++++-------------- 6 files changed, 110 insertions(+), 61 deletions(-) diff --git a/python/jax/carfac.py b/python/jax/carfac.py index 2d622e1..86b53da 100644 --- a/python/jax/carfac.py +++ b/python/jax/carfac.py @@ -2301,7 +2301,9 @@ def run_segment( weights: CarfacWeights, state: CarfacState, open_loop: bool = False, -) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> Tuple[ + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray +]: """This function runs the entire CARFAC model. That is, filters a 1 or more channel @@ -2335,6 +2337,8 @@ def run_segment( Returns: naps: neural activity pattern + naps_fibers: neural activity of different fibers + (only populated with non-zeros when ihc_style equals "two_cap_with_syn") state: the updated state of the CARFAC model. BM: The basilar membrane motion seg_ohc & seg_agc are optional extra outputs useful for seeing what the @@ -2343,6 +2347,7 @@ def run_segment( if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) [n_samp, n_ears] = input_waves.shape + n_fibertypes = SynDesignParameters.n_classes # TODO(honglinyu): add more assertions using checkify. # if n_ears != cfp.n_ears: @@ -2352,6 +2357,7 @@ def run_segment( n_ch = hypers.ears[0].car.n_ch naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result + naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears)) bm = jnp.zeros((n_samp, n_ch, n_ears)) seg_ohc = jnp.zeros((n_samp, n_ch, n_ears)) seg_agc = jnp.zeros((n_samp, n_ch, n_ears)) @@ -2370,7 +2376,7 @@ def run_segment( # Note that we can use naive for loops here because it will make gradient # computation very slow. def run_segment_scan_helper(carry, k): - naps, state, bm, seg_ohc, seg_agc, input_waves = carry + naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves = carry agc_updated = False for ear in range(n_ears): # This would be cleaner if we could just get and use a reference to @@ -2385,9 +2391,14 @@ def run_segment_scan_helper(carry, k): ) if hypers.ears[ear].syn.do_syn: - ihc_out, _, state.ears[ear].syn = syn_step( + ihc_out, firings, state.ears[ear].syn = syn_step( v_recep, ear, weights, state.ears[ear].syn ) + naps_fibers = naps_fibers.at[k, :, :, ear].set(firings) + else: + naps_fibers = naps_fibers.at[k, :, :, ear].set( + jnp.zeros([jnp.shape(ihc_out)[0], n_fibertypes]) + ) # run the AGC update step, decimating internally, agc_updated, state.ears[ear].agc = agc_step( @@ -2420,11 +2431,11 @@ def close_agc_loop_helper( state, ) - return (naps, state, bm, seg_ohc, seg_agc, input_waves), None + return (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), None return jax.lax.scan( run_segment_scan_helper, - (naps, state, bm, seg_ohc, seg_agc, input_waves), + (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), jnp.arange(n_samp), )[0][:-1] @@ -2442,7 +2453,9 @@ def run_segment_jit( weights: CarfacWeights, state: CarfacState, open_loop: bool = False, -) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> Tuple[ + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray +]: """A JITted version of run_segment for convenience. Care should be taken with the hyper parameters in hypers. If the hypers object @@ -2468,6 +2481,8 @@ def run_segment_jit( Returns: naps: neural activity pattern + naps_fibers: neural activity of the different fiber types + (only populated with non-zeros when ihc_style equals "two_cap_with_syn") state: the updated state of the CARFAC model. BM: The basilar membrane motion seg_ohc & seg_agc are optional extra outputs useful for seeing what the @@ -2483,7 +2498,9 @@ def run_segment_jit_in_chunks_notraceable( state: CarfacState, open_loop: bool = False, segment_chunk_length: int = 32 * 48000, -) -> tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]: +) -> tuple[ + jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray +]: """Runs the jitted segment runner in segment groups. Running CARFAC on an audio segment this way is most useful when running @@ -2526,6 +2543,7 @@ def run_segment_jit_in_chunks_notraceable( if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) naps_out = [] + naps_fibers_out = [] bm_out = [] ohc_out = [] agc_out = [] @@ -2534,10 +2552,11 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape if n_samp >= segment_length: [current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0) - naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit( - current_waves, hypers, weights, state, open_loop + naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = ( + run_segment_jit(current_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) + naps_fibers_out.append(naps_fibers_jax) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) @@ -2546,15 +2565,17 @@ def run_segment_jit_in_chunks_notraceable( [n_samp, _] = input_waves.shape # Take the last few items and just run them. if n_samp > 0: - naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit( - input_waves, hypers, weights, state, open_loop + naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = ( + run_segment_jit(input_waves, hypers, weights, state, open_loop) ) naps_out.append(naps_jax) + naps_fibers_out.append(naps_fibers_jax) bm_out.append(bm_jax) ohc_out.append(seg_ohc_jax) agc_out.append(seg_agc_jax) naps_out = np.concatenate(naps_out, 0) + naps_fibers_out = np.concatenate(naps_fibers_out, 0) bm_out = np.concatenate(bm_out, 0) ohc_out = np.concatenate(ohc_out, 0) agc_out = np.concatenate(agc_out, 0) - return naps_out, state, bm_out, ohc_out, agc_out + return naps_out, naps_fibers_out, state, bm_out, ohc_out, agc_out diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index 6a07e6f..ca58785 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -179,7 +179,7 @@ def loss_func( weights: carfac_jax.CarfacWeights, state: carfac_jax.CarfacState, ): - nap_output, _, _, _, _ = carfac_jax.run_segment( + nap_output, _, _, _, _, _ = carfac_jax.run_segment( audio, hypers, weights, state ) return jnp.sum(nap_output), nap_output @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State): # that this benchmark is appropriate. n_samp += 1 state.resume_timing() - naps_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State): for _, segment in enumerate(silence_slices): if segment.shape not in compiled_shapes: compiled_shapes.add(segment.shape) - naps_jax, _, _, _, _ = carfac_jax.run_segment_jit( + naps_jax, _, _, _, _, _ = carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, state_jax, open_loop=False ) naps_jax.block_until_ready() @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State): jax_loop_state = state_jax state.resume_timing() for _, segment in enumerate(run_seg_slices): - seg_naps, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( + seg_naps, _, jax_loop_state, seg_bm, seg_ohc, seg_agc = ( carfac_jax.run_segment_jit( segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False ) @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State): params_jax ) short_silence = jnp.zeros(shape=(n_samp, n_ears)) - naps_jax, state_jax, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _ = run_segment_function( short_silence, hypers_jax, weights_jax, state_jax, open_loop=False ) # This block ensures calculation. @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State): jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR ).block_until_ready() state.resume_timing() - naps_jax, state_jax, _, _, _ = run_segment_function( + naps_jax, _, state_jax, _, _, _ = run_segment_function( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) if state.range(0) != 1: diff --git a/python/jax/carfac_float64_test.py b/python/jax/carfac_float64_test.py index 921dc7d..3237024 100644 --- a/python/jax/carfac_float64_test.py +++ b/python/jax/carfac_float64_test.py @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state): # A loss function for tests. Note that we shouldn't use `run_segment_jit` # here because it will donate the `state` which causes unnecessary # inconvenience for tests. - naps_jax, state_jax, _, _, _ = carfac_jax.run_segment( + naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment( input_waves, hypers, weights, state, open_loop=False ) # For testing, just fit `naps` to 1. diff --git a/python/jax/carfac_test.py b/python/jax/carfac_test.py index 918d074..246b5f0 100644 --- a/python/jax/carfac_test.py +++ b/python/jax/carfac_test.py @@ -324,18 +324,25 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): state_jax_copied = copy.deepcopy(state_jax) # Only tests the JITted version because this is what we will use. - naps_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit( - run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False - ) - naps_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = ( - carfac_jax.run_segment_jit_in_chunks_notraceable( - run_seg_input, - hypers_jax, - weights_jax, - state_jax_copied, - open_loop=False, + naps_jax, _, _, bm_jax, ohc_jax, agc_jax = ( + carfac_jax.run_segment_jit( + run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) ) + ( + naps_jax_chunked, + _, + _, + bm_chunked, + ohc_chunked, + agc_chunked, + ) = carfac_jax.run_segment_jit_in_chunks_notraceable( + run_seg_input, + hypers_jax, + weights_jax, + state_jax_copied, + open_loop=False, + ) self.assertLess(jnp.max(abs(naps_jax_chunked - naps_jax)), 1e-7) self.assertLess(jnp.max(abs(bm_chunked - bm_jax)), 1e-7) self.assertLess(jnp.max(abs(ohc_chunked - ohc_jax)), 1e-7) @@ -380,7 +387,7 @@ def test_equal_forward_pass( run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) # Only tests the JITted version because this is what we will use. - naps_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( + naps_jax, _, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( carfac_jax.run_segment_jit( run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False ) diff --git a/python/jax/carfac_util.py b/python/jax/carfac_util.py index 9071b09..d789415 100644 --- a/python/jax/carfac_util.py +++ b/python/jax/carfac_util.py @@ -38,6 +38,7 @@ def run_multiple_segment_states_shmap( open_loop: bool = False, ) -> Sequence[ Tuple[ + jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, jnp.ndarray, @@ -88,23 +89,31 @@ def parallel_helper(input_waves, state): """ input_waves = input_waves[0] state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state) - naps, ret_state, bm, seg_ohc, seg_agc = carfac_jax.run_segment_jit( - input_waves, hypers, weights, state, open_loop + naps, naps_fibers, ret_state, bm, seg_ohc, seg_agc = ( + carfac_jax.run_segment_jit( + input_waves, hypers, weights, state, open_loop + ) ) ret_state = jax.tree_util.tree_map( lambda x: jnp.asarray(x).reshape((1, -1)), ret_state ) return ( naps[None], + naps_fibers[None], ret_state, bm[None], seg_ohc[None], seg_agc[None], ) - stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = ( - parallel_helper(input_waves_array, batch_state) - ) + ( + stacked_naps, + stacked_naps_fibers, + stacked_states, + stacked_bm, + stacked_ohc, + stacked_agc, + ) = parallel_helper(input_waves_array, batch_state) output_states = _tree_unstack(stacked_states) output = [] # TODO(robsc): Modify this for loop to a jax.lax loop, and then JIT the @@ -112,6 +121,7 @@ def parallel_helper(input_waves, state): for i, output_state in enumerate(output_states): tup = ( stacked_naps[i], + stacked_naps_fibers[i], output_state, stacked_bm[i], stacked_ohc[i], @@ -130,6 +140,7 @@ def run_multiple_segment_pmap( open_loop: bool = False, ) -> Sequence[ Tuple[ + jnp.ndarray, jnp.ndarray, carfac_jax.CarfacState, jnp.ndarray, @@ -155,15 +166,21 @@ def run_multiple_segment_pmap( in_axes=(0, None, None, None, None), static_broadcasted_argnums=[1, 4], ) - stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = pmapped( - input_waves_array, hypers, weights, state, open_loop - ) + ( + stacked_naps, + stacked_naps_fibers, + stacked_states, + stacked_bm, + stacked_ohc, + stacked_agc, + ) = pmapped(input_waves_array, hypers, weights, state, open_loop) output_states = _tree_unstack(stacked_states) output = [] for i, output_state in enumerate(output_states): tup = ( stacked_naps[i], + stacked_naps_fibers[i], output_state, stacked_bm[i], stacked_ohc[i], diff --git a/python/jax/carfac_util_test.py b/python/jax/carfac_util_test.py index 753238b..a88c449 100644 --- a/python/jax/carfac_util_test.py +++ b/python/jax/carfac_util_test.py @@ -59,7 +59,7 @@ def test_same_outputs_parallel_for_pmap(self): ], axis=0, ) - nap_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -68,7 +68,7 @@ def test_same_outputs_parallel_for_pmap(self): self.open_loop, ) ) - nap_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -86,20 +86,22 @@ def test_same_outputs_parallel_for_pmap(self): ) self.assertTrue((combined_output[0][0] == nap_out_a).all()) self.assertTrue((combined_output[1][0] == nap_out_b).all()) - self.assertTrue((combined_output[0][2] == bm_out_a).all()) - self.assertTrue((combined_output[1][2] == bm_out_b).all()) - self.assertTrue((combined_output[0][3] == ohc_out_a).all()) - self.assertTrue((combined_output[1][3] == ohc_out_b).all()) - self.assertTrue((combined_output[0][4] == agc_out_a).all()) - self.assertTrue((combined_output[1][4] == agc_out_b).all()) + self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) + self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) + self.assertTrue((combined_output[0][3] == bm_out_a).all()) + self.assertTrue((combined_output[1][3] == bm_out_b).all()) + self.assertTrue((combined_output[0][4] == ohc_out_a).all()) + self.assertTrue((combined_output[1][4] == ohc_out_b).all()) + self.assertTrue((combined_output[0][5] == agc_out_a).all()) + self.assertTrue((combined_output[1][5] == agc_out_b).all()) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][1]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][1]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) ) ) @@ -112,7 +114,7 @@ def test_same_outputs_parallel_for_shmap(self): axis=0, ) - nap_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( + nap_out_a, nap_fibers_out_a, state_out_a, bm_out_a, ohc_out_a, agc_out_a = ( carfac.run_segment_jit( self.sample_a, self.hypers, @@ -124,7 +126,7 @@ def test_same_outputs_parallel_for_shmap(self): # Run sample B twice, so we have a separate "starting" state for the # test for shmap. - _, state_out_b_first, _, _, _ = carfac.run_segment_jit( + _, _, state_out_b_first, _, _, _ = carfac.run_segment_jit( self.sample_b, self.hypers, self.weights, @@ -132,7 +134,7 @@ def test_same_outputs_parallel_for_shmap(self): self.open_loop, ) - nap_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( + nap_out_b, nap_fibers_out_b, state_out_b, bm_out_b, ohc_out_b, agc_out_b = ( carfac.run_segment_jit( self.sample_b, self.hypers, @@ -150,20 +152,22 @@ def test_same_outputs_parallel_for_shmap(self): ) self.assertTrue((combined_output[0][0] == nap_out_a).all()) self.assertTrue((combined_output[1][0] == nap_out_b).all()) - self.assertTrue((combined_output[0][2] == bm_out_a).all()) - self.assertTrue((combined_output[1][2] == bm_out_b).all()) - self.assertTrue((combined_output[0][3] == ohc_out_a).all()) - self.assertTrue((combined_output[1][3] == ohc_out_b).all()) - self.assertTrue((combined_output[0][4] == agc_out_a).all()) - self.assertTrue((combined_output[1][4] == agc_out_b).all()) + self.assertTrue((combined_output[0][1] == nap_fibers_out_a).all()) + self.assertTrue((combined_output[1][1] == nap_fibers_out_b).all()) + self.assertTrue((combined_output[0][3] == bm_out_a).all()) + self.assertTrue((combined_output[1][3] == bm_out_b).all()) + self.assertTrue((combined_output[0][4] == ohc_out_a).all()) + self.assertTrue((combined_output[1][4] == ohc_out_b).all()) + self.assertTrue((combined_output[0][5] == agc_out_a).all()) + self.assertTrue((combined_output[1][5] == agc_out_b).all()) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_a, combined_output[0][1]) + jax.tree.map(jnp.allclose, state_out_a, combined_output[0][2]) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, state_out_b, combined_output[1][1]) + jax.tree.map(jnp.allclose, state_out_b, combined_output[1][2]) ) ) @@ -171,12 +175,12 @@ def test_same_outputs_parallel_for_shmap(self): # equality is complete and double sided. self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[0][1], state_out_a) + jax.tree.map(jnp.allclose, combined_output[0][2], state_out_a) ) ) self.assertTrue( jax.tree_util.tree_all( - jax.tree.map(jnp.allclose, combined_output[1][1], state_out_b) + jax.tree.map(jnp.allclose, combined_output[1][2], state_out_b) ) )