|
13 | 13 | _atol_for_type, |
14 | 14 | _average, |
15 | 15 | _convert_to_numpy, |
| 16 | + _count_nonzero, |
16 | 17 | _estimator_with_converted_arrays, |
17 | 18 | _is_numpy_namespace, |
18 | 19 | _nanmax, |
|
30 | 31 | _array_api_for_tests, |
31 | 32 | skip_if_array_api_compat_not_configured, |
32 | 33 | ) |
33 | | -from sklearn.utils.fixes import _IS_32BIT |
| 34 | +from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS |
34 | 35 |
|
35 | 36 |
|
36 | 37 | @pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]]) |
@@ -530,3 +531,37 @@ def test_get_namespace_and_device(): |
530 | 531 | assert namespace is xp_torch |
531 | 532 | assert is_array_api |
532 | 533 | assert device == some_torch_tensor.device |
| 534 | + |
| 535 | + |
| 536 | +@pytest.mark.parametrize( |
| 537 | + "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations() |
| 538 | +) |
| 539 | +@pytest.mark.parametrize("csr_container", CSR_CONTAINERS) |
| 540 | +@pytest.mark.parametrize("axis", [0, 1, None, -1, -2]) |
| 541 | +@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"]) |
| 542 | +def test_count_nonzero( |
| 543 | + array_namespace, device, dtype_name, csr_container, axis, sample_weight_type |
| 544 | +): |
| 545 | + |
| 546 | + from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero |
| 547 | + |
| 548 | + xp = _array_api_for_tests(array_namespace, device) |
| 549 | + array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]]) |
| 550 | + if sample_weight_type == "int": |
| 551 | + sample_weight = numpy.asarray([1, 2, 2, 3, 1]) |
| 552 | + elif sample_weight_type == "float": |
| 553 | + sample_weight = numpy.asarray([0.5, 1.5, 0.8, 3.2, 2.4], dtype=dtype_name) |
| 554 | + else: |
| 555 | + sample_weight = None |
| 556 | + expected = sparse_count_nonzero( |
| 557 | + csr_container(array), axis=axis, sample_weight=sample_weight |
| 558 | + ) |
| 559 | + array_xp = xp.asarray(array, device=device) |
| 560 | + |
| 561 | + with config_context(array_api_dispatch=True): |
| 562 | + result = _count_nonzero( |
| 563 | + array_xp, xp=xp, device=device, axis=axis, sample_weight=sample_weight |
| 564 | + ) |
| 565 | + |
| 566 | + assert_allclose(_convert_to_numpy(result, xp=xp), expected) |
| 567 | + assert getattr(array_xp, "device", None) == getattr(result, "device", None) |
0 commit comments