diff --git a/pypfopt/risk_models.py b/pypfopt/risk_models.py index cf62509f..f526c9db 100644 --- a/pypfopt/risk_models.py +++ b/pypfopt/risk_models.py @@ -25,6 +25,7 @@ import numpy as np import pandas as pd +from skbase.utils.dependencies import _check_soft_dependencies from .expected_returns import returns_from_prices @@ -298,11 +299,14 @@ def min_cov_determinant( warnings.warn("data is not in a dataframe", RuntimeWarning) prices = pd.DataFrame(prices) - # Extra dependency - try: - import sklearn.covariance - except (ModuleNotFoundError, ImportError): - raise ImportError("Please install scikit-learn via pip or poetry") + if not _check_soft_dependencies(["scikit-learn"], severity="none"): + raise ImportError( + "scikit-learn is required to use min_cov_determinant. " + "Please ensure that scikit-learn is installed in your environment," + " e.g via pip install scikit-learn" + ) + + from sklearn.covariance import fast_mcd assets = prices.columns @@ -312,7 +316,7 @@ def min_cov_determinant( X = returns_from_prices(prices, log_returns) # X = np.nan_to_num(X.values) X = X.dropna().values - raw_cov_array = sklearn.covariance.fast_mcd(X, random_state=random_state)[1] + raw_cov_array = fast_mcd(X, random_state=random_state)[1] cov = pd.DataFrame(raw_cov_array, index=assets, columns=assets) * frequency return fix_nonpositive_semidefinite(cov, kwargs.get("fix_method", "spectral")) @@ -379,13 +383,16 @@ def __init__(self, prices, returns_data=False, frequency=252, log_returns=False) :param log_returns: whether to compute using log returns :type log_returns: bool, defaults to False """ - # Optional import - try: - from sklearn import covariance + if not _check_soft_dependencies(["scikit-learn"], severity="none"): + raise ImportError( + "scikit-learn is required to use CovarianceShrinkage. " + "Please ensure that scikit-learn is installed in your environment," + " e.g via pip install scikit-learn" + ) + + from sklearn import covariance - self.covariance = covariance - except (ModuleNotFoundError, ImportError): # pragma: no cover - raise ImportError("Please install scikit-learn via pip or poetry") + self.covariance = covariance if not isinstance(prices, pd.DataFrame): warnings.warn("data is not in a dataframe", RuntimeWarning) diff --git a/pyproject.toml b/pyproject.toml index f18b8fbc..12788579 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,9 +35,9 @@ dependencies = [ "cvxpy>=1.1.19", "numpy>=1.26.0", "pandas>=0.19", + "scikit-base<0.14.0", "scikit-learn>=0.24.1", "scipy>=1.3.0", - "scikit-base<0.14.0", ] [project.optional-dependencies] @@ -54,7 +54,6 @@ dependencies = [ all_extras = [ "matplotlib>=3.2.0", "plotly>=5.0.0,<6", - "scikit-learn>=0.24.1", "ecos>=2.0.14,<2.1", "plotly>=5.0.0,<7", "cvxopt; python_version < '3.14'",