diff --git a/src/cell_type_mapper/diff_exp/score_utils.py b/src/cell_type_mapper/diff_exp/score_utils.py index 843d198c..b4b12119 100644 --- a/src/cell_type_mapper/diff_exp/score_utils.py +++ b/src/cell_type_mapper/diff_exp/score_utils.py @@ -187,8 +187,10 @@ def read_raw_precomputed_stats( precomputed_stats['gene_names'] = json.loads( in_file['col_names'][()].decode('utf-8')) - row_lookup = json.loads( - in_file['cluster_to_row'][()].decode('utf-8')) + row_lookup = _read_cluster_to_row( + json.loads( + in_file['cluster_to_row'][()].decode('utf-8')) + ) all_keys = set(['n_cells', 'sum', 'sumsq', 'gt0', 'gt1', 'ge1']) all_keys = list(all_keys.intersection(set(in_file.keys()))) @@ -286,13 +288,13 @@ def aggregate_stats( sumsq_arr += these_stats['sumsq'] if 'gt0' in these_stats: - gt0 += these_stats['gt0'] + gt0 += list(map(int, these_stats['gt0'])) if 'gt1' in these_stats: - gt1 += these_stats['gt1'] + gt1 += list(map(int, these_stats['gt1'])) if 'ge1' in these_stats: - ge1 += these_stats['ge1'] + ge1 += list(map(int, these_stats['ge1'])) else: has_ge1 = False @@ -317,3 +319,24 @@ def aggregate_stats( result[k] = result[k].astype(new_dtype) return result + + +def _read_cluster_to_row(cluster_row_lookup): + """ + Take cluster row lookup dict. + Return cluster row lookup dict with int casted values. + + Parameters + ---------- + cluster_row_lookup: + A dictionary mapping cluster names to index values + for looking up n_cells for a given cluster. + + Returns + ------- + cluster_to_row + """ + cluster_to_row = {} + for cluster, row in cluster_row_lookup.items(): + cluster_to_row[cluster] = int(row) + return cluster_to_row \ No newline at end of file diff --git a/src/cell_type_mapper/utils/output_utils.py b/src/cell_type_mapper/utils/output_utils.py index 293d2f1d..22908cd7 100644 --- a/src/cell_type_mapper/utils/output_utils.py +++ b/src/cell_type_mapper/utils/output_utils.py @@ -555,7 +555,7 @@ def precomputed_stats_to_uns( def uns_to_precomputed_stats( h5ad_path, - uns_key, + uns_keys_list, tmp_dir=None): """ Read a serialized precomputed stats file from the uns element @@ -593,7 +593,10 @@ def uns_to_precomputed_stats( ['n_cells', 'sum', 'sumsq', 'ge1', 'gt1', 'gt0'] ) - serialized_data = read_uns_from_h5ad(h5ad_path)[uns_key] + serialized_data = read_uns_from_h5ad(h5ad_path) + for uns_key in uns_keys_list: + serialized_data = serialized_data[uns_key] + with h5py.File(h5_path, 'w') as dst: for dataset_name in serialized_data: data = serialized_data[dataset_name] diff --git a/tests/utils/test_output_utils.py b/tests/utils/test_output_utils.py index 3d6d0b47..7ec5e0de 100644 --- a/tests/utils/test_output_utils.py +++ b/tests/utils/test_output_utils.py @@ -468,9 +468,11 @@ def test_re_order_blob(tmp_dir_fixture): assert len(new_blob) == len(results_lookup) +@pytest.mark.parametrize('nested_uns', [True, False]) def test_precomputed_stats_to_uns( tmp_dir_fixture, - precomputed_stats_fixture): + precomputed_stats_fixture, + nested_uns): """ Test utility functions to move precomputed stats data from HDF5 to the uns element of an h5ad file and back. @@ -507,16 +509,26 @@ def test_precomputed_stats_to_uns( a_data.write_h5ad(h5ad_path) uns_key = 'serialization_test' + uns_key_nested = 'serialization_test_nested' precomputed_stats_to_uns( precomputed_stats_path=precomputed_stats_fixture, h5ad_path=h5ad_path, uns_key=uns_key) - roundtrip_path = uns_to_precomputed_stats( - uns_key=uns_key, - h5ad_path=h5ad_path, - tmp_dir=tmp_dir_fixture) + if not nested_uns: + roundtrip_path = uns_to_precomputed_stats( + uns_keys_list=[uns_key], + h5ad_path=h5ad_path, + tmp_dir=tmp_dir_fixture) + else: + a_data = anndata.read_h5ad(h5ad_path) + a_data.uns[uns_key] = {uns_key_nested: a_data.uns[uns_key]} + a_data.write_h5ad(h5ad_path) + roundtrip_path = uns_to_precomputed_stats( + uns_keys_list=[uns_key, uns_key_nested], + h5ad_path=h5ad_path, + tmp_dir=tmp_dir_fixture) with h5py.File(precomputed_stats_fixture, 'r') as expected_src: with h5py.File(roundtrip_path, 'r') as actual_src: @@ -540,3 +552,4 @@ def test_precomputed_stats_to_uns( original_uns['maybe'], roundtrip_h5ad.uns['maybe'] ) +