diff --git a/docs/api/operations.md b/docs/api/operations.md index 937b8dbc..3eb2a5a6 100644 --- a/docs/api/operations.md +++ b/docs/api/operations.md @@ -29,4 +29,6 @@ Operations on `SpatialData` objects. .. autofunction:: are_extents_equal .. autofunction:: deepcopy .. autofunction:: get_pyramid_levels +.. autofunction:: sanitize_name +.. autofunction:: sanitize_table ``` diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 9ddfea32..0b68391a 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -53,10 +53,13 @@ "relabel_sequential", "map_raster", "deepcopy", + "sanitize_table", + "sanitize_name", ] from spatialdata import dataloader, datasets, models, transformations from spatialdata._core._deepcopy import deepcopy +from spatialdata._core._utils import sanitize_name, sanitize_table from spatialdata._core.centroids import get_centroids from spatialdata._core.concatenate import concatenate from spatialdata._core.data_extent import are_extents_equal, get_extent diff --git a/src/spatialdata/_core/_utils.py b/src/spatialdata/_core/_utils.py index dd27e9c8..a5581565 100644 --- a/src/spatialdata/_core/_utils.py +++ b/src/spatialdata/_core/_utils.py @@ -1,5 +1,9 @@ +from __future__ import annotations + from collections.abc import Iterable +from anndata import AnnData + from spatialdata._core.spatialdata import SpatialData @@ -25,3 +29,138 @@ def _find_common_table_keys(sdatas: Iterable[SpatialData]) -> set[str]: common_keys.intersection_update(sdata.tables.keys()) return common_keys + + +def sanitize_name(name: str, is_dataframe_column: bool = False) -> str: + """ + Sanitize a name to comply with SpatialData naming rules. + + This function converts invalid names into valid ones by: + 1. Converting to string if not already + 2. Replacing invalid characters with underscores + 3. Handling special cases like "__" prefix + 4. Ensuring the name is not empty + 5. 
Handling special cases for dataframe columns + + See a discussion on the naming rules, and how to avoid naming collisions, here: + https://github.com/scverse/spatialdata/discussions/707 + + Parameters + ---------- + name + The name to sanitize + is_dataframe_column + Whether this name is for a dataframe column (additional restrictions apply) + + Returns + ------- + A sanitized version of the name that complies with SpatialData naming rules. If a + sanitized name cannot be generated, it returns "unnamed". + + Examples + -------- + >>> sanitize_name("my@invalid#name") + 'my_invalid_name' + >>> sanitize_name("__private") + '_private' + >>> sanitize_name("_index", is_dataframe_column=True) + 'index' + """ + # Convert to string if not already + name = str(name) + + # Handle empty string case + if not name: + return "unnamed" + + # Handle special cases + if name in {".", ".."}: + return "unnamed" + + sanitized = "".join(char if char.isalnum() or char in "_-." else "_" for char in name) + + # remove double underscores if found as a prefix + while sanitized.startswith("__"): + sanitized = sanitized[1:] + + if is_dataframe_column and sanitized == "_index": + return "index" + + # Ensure we don't end up with an empty string after sanitization + return sanitized or "unnamed" + + +def sanitize_table(data: AnnData, inplace: bool = True) -> AnnData | None: + """ + Sanitize all keys in an AnnData table to comply with SpatialData naming rules. + + This function sanitizes all keys in obs, var, obsm, obsp, varm, varp, uns, and layers + while maintaining case-insensitive uniqueness. It can either modify the table in-place + or return a new sanitized copy. + + See a discussion on the naming rules here: + https://github.com/scverse/spatialdata/discussions/707 + + Parameters + ---------- + data + The AnnData table to sanitize + inplace + Whether to modify the table in-place or return a new copy + + Returns + ------- + If inplace is False, returns a new AnnData object with sanitized keys. 
+ If inplace is True, returns None as the original object is modified. + + Examples + -------- + >>> import anndata as ad + >>> adata = ad.AnnData(obs=pd.DataFrame({"@invalid#": [1, 2]})) + >>> # Create a new sanitized copy + >>> sanitized = sanitize_table(adata, inplace=False) + >>> print(sanitized.obs.columns) + Index(['_invalid_'], dtype='object') + >>> # Or modify in-place + >>> sanitize_table(adata, inplace=True) + >>> print(adata.obs.columns) + Index(['_invalid_'], dtype='object') + """ + import copy + from collections import defaultdict + + # Create a deep copy if not modifying in-place + sanitized = data if inplace else copy.deepcopy(data) + + # Track used names to maintain case-insensitive uniqueness + used_names_lower: dict[str, set[str]] = defaultdict(set) + + def get_unique_name(name: str, attr: str, is_dataframe_column: bool = False) -> str: + base_name = sanitize_name(name, is_dataframe_column) + normalized_base = base_name.lower() + + # If this name is already used (case-insensitively), add a number + if normalized_base in used_names_lower[attr]: + counter = 1 + while f"{base_name}_{counter}".lower() in used_names_lower[attr]: + counter += 1 + base_name = f"{base_name}_{counter}" + + used_names_lower[attr].add(base_name.lower()) + return base_name + + # Handle obs and var (dataframe columns) + for attr in ("obs", "var"): + df = getattr(sanitized, attr) + new_columns = {old: get_unique_name(old, attr, is_dataframe_column=True) for old in df.columns} + df.rename(columns=new_columns, inplace=True) + + # Handle other attributes + for attr in ("obsm", "obsp", "varm", "varp", "uns", "layers"): + d = getattr(sanitized, attr) + new_keys = {old: get_unique_name(old, attr) for old in d} + # Create new dictionary with sanitized keys + new_dict = {new_keys[old]: value for old, value in d.items()} + setattr(sanitized, attr, new_dict) + + return None if inplace else sanitized diff --git a/src/spatialdata/_core/validation.py b/src/spatialdata/_core/validation.py index d10ba357..8c038a4a 100644 
--- a/src/spatialdata/_core/validation.py +++ b/src/spatialdata/_core/validation.py @@ -379,5 +379,8 @@ def __exit__( return False # Exceptions were collected that we want to raise as a combined validation error. if self._collector.errors: - raise ValidationError(title=self._message, errors=self._collector.errors) + raise ValidationError( + title=self._message + "\nTo fix, run `spatialdata.utils.sanitize_table(adata)`.", + errors=self._collector.errors, + ) return True diff --git a/tests/utils/test_sanitize.py b/tests/utils/test_sanitize.py new file mode 100644 index 00000000..b4999a48 --- /dev/null +++ b/tests/utils/test_sanitize.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData + +from spatialdata import SpatialData +from spatialdata._core._utils import sanitize_name, sanitize_table + + +@pytest.fixture +def invalid_table() -> AnnData: + """AnnData with invalid obs column names to test basic sanitization.""" + return AnnData( + obs=pd.DataFrame( + { + "@invalid#": [1, 2], + "valid_name": [3, 4], + "__private": [5, 6], + } + ) + ) + + +@pytest.fixture +def invalid_table_with_index() -> AnnData: + """AnnData with a name requiring whitespace→underscore and a dataframe index column.""" + return AnnData( + obs=pd.DataFrame( + { + "invalid name": [1, 2], + "_index": [3, 4], + } + ) + ) + + +# ----------------------------------------------------------------------------- +# sanitize_name tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "raw,expected", + [ + ("valid_name", "valid_name"), + ("valid-name", "valid-name"), + ("valid.name", "valid.name"), + ("invalid@name", "invalid_name"), + ("invalid#name", "invalid_name"), + ("invalid name", "invalid_name"), + ("", "unnamed"), + (".", "unnamed"), + ("..", "unnamed"), + ("__", "_"), + ("___", "_"), + ("____#@$@", "_"), + ("__private", "_private"), + ], +) +def 
test_sanitize_name_strips_special_chars(raw, expected): + assert sanitize_name(raw) == expected + + +@pytest.mark.parametrize( + "raw,is_df_col,expected", + [ + ("_index", True, "index"), + ("_index", False, "_index"), + ("valid@column", True, "valid_column"), + ("__private", True, "_private"), + ], +) +def test_sanitize_name_dataframe_column(raw, is_df_col, expected): + assert sanitize_name(raw, is_dataframe_column=is_df_col) == expected + + +# ----------------------------------------------------------------------------- +# sanitize_table basic behaviors +# ----------------------------------------------------------------------------- + + +def test_sanitize_table_basic_columns(invalid_table, invalid_table_with_index): + ad1 = sanitize_table(invalid_table, inplace=False) + assert isinstance(ad1, AnnData) + assert list(ad1.obs.columns) == ["_invalid_", "valid_name", "_private"] + + ad2 = sanitize_table(invalid_table_with_index, inplace=False) + assert list(ad2.obs.columns) == ["invalid_name", "index"] + + # original fixture remains unchanged + assert list(invalid_table.obs.columns) == ["@invalid#", "valid_name", "__private"] + + +def test_sanitize_table_inplace_copy(invalid_table): + ad = invalid_table.copy() + sanitize_table(ad) # inplace=True is now default + assert list(ad.obs.columns) == ["_invalid_", "valid_name", "_private"] + + +def test_sanitize_table_case_insensitive_collisions(): + obs = pd.DataFrame( + { + "Column1": [1, 2], + "column1": [3, 4], + "COLUMN1": [5, 6], + } + ) + ad = AnnData(obs=obs) + sanitized = sanitize_table(ad, inplace=False) + cols = list(sanitized.obs.columns) + assert sorted(cols) == sorted(["Column1", "column1_1", "COLUMN1_2"]) + + +def test_sanitize_table_whitespace_collision(): + """Ensure 'a b' → 'a_b' doesn't collide silently with existing 'a_b'.""" + obs = pd.DataFrame({"a b": [1], "a_b": [2]}) + ad = AnnData(obs=obs) + sanitized = sanitize_table(ad, inplace=False) + cols = list(sanitized.obs.columns) + assert "a_b" in cols + 
assert "a_b_1" in cols + + +# ----------------------------------------------------------------------------- +# sanitize_table attribute‐specific tests +# ----------------------------------------------------------------------------- + + +def test_sanitize_table_obs_and_obs_columns(): + ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.obs.columns) == ["_col"] + + +def test_sanitize_table_obsm_and_obsp(): + ad = AnnData(obs=pd.DataFrame({"@col": [1, 2]})) + ad.obsm["@col"] = np.array([[1, 2], [3, 4]]) + ad.obsp["bad name"] = np.array([[1, 2], [3, 4]]) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.obsm.keys()) == ["_col"] + assert list(sanitized.obsp.keys()) == ["bad_name"] + + +def test_sanitize_table_varm_and_varp(): + ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) + ad.varm["__priv"] = np.array([[1, 2], [3, 4]]) + ad.varp["_index"] = np.array([[1, 2], [3, 4]]) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.varm.keys()) == ["_priv"] + assert list(sanitized.varp.keys()) == ["_index"] + + +def test_sanitize_table_uns_and_layers(): + ad = AnnData(obs=pd.DataFrame({"x": [1, 2]}), var=pd.DataFrame(index=["v1", "v2"])) + ad.uns["bad@key"] = "val" + ad.layers["bad#layer"] = np.array([[0, 1], [1, 0]]) + sanitized = sanitize_table(ad, inplace=False) + assert list(sanitized.uns.keys()) == ["bad_key"] + assert list(sanitized.layers.keys()) == ["bad_layer"] + + +def test_sanitize_table_empty_returns_empty(): + ad = AnnData() + sanitized = sanitize_table(ad, inplace=False) + assert isinstance(sanitized, AnnData) + assert sanitized.obs.empty + assert sanitized.var.empty + + +def test_sanitize_table_preserves_underlying_data(): + ad = AnnData(obs=pd.DataFrame({"@invalid#": [1, 2], "valid": [3, 4]})) + ad.obsm["@invalid#"] = np.array([[1, 2], [3, 4]]) + ad.uns["invalid@key"] = "value" + sanitized = sanitize_table(ad, 
inplace=False) + assert sanitized.obs["_invalid_"].tolist() == [1, 2] + assert sanitized.obs["valid"].tolist() == [3, 4] + assert np.array_equal(sanitized.obsm["_invalid_"], np.array([[1, 2], [3, 4]])) + assert sanitized.uns["invalid_key"] == "value" + + +# ----------------------------------------------------------------------------- +# SpatialData integration +# ----------------------------------------------------------------------------- + + +def test_sanitize_table_in_spatialdata_sanitized_fixture(invalid_table, invalid_table_with_index): + table1 = invalid_table.copy() + table2 = invalid_table_with_index.copy() + sanitize_table(table1) + sanitize_table(table2) + sdata_sanitized_tables = SpatialData(tables={"table1": table1, "table2": table2}) + + t1 = sdata_sanitized_tables.tables["table1"] + t2 = sdata_sanitized_tables.tables["table2"] + assert list(t1.obs.columns) == ["_invalid_", "valid_name", "_private"] + assert list(t2.obs.columns) == ["invalid_name", "index"] + + +def test_spatialdata_retains_other_elements(full_sdata): + # Add another sanitized table into an existing full_sdata + tbl = AnnData(obs=pd.DataFrame({"@foo#": [1, 2], "bar": [3, 4]})) + sanitize_table(tbl) + full_sdata.tables["new_table"] = tbl + + # Verify columns and presence of other SpatialData attributes + assert list(full_sdata.tables["new_table"].obs.columns) == ["_foo_", "bar"]