Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 67 additions & 21 deletions pygam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,49 @@ def check_lengths(*arrays):
raise ValueError(f"Inconsistent data lengths: {lengths}")


_CONSTRAINT_OPS = {
">= 0": lambda a: np.all(a >= 0),
">=0": lambda a: np.all(a >= 0),
">= 1": lambda a: np.all(a >= 1),
">=1": lambda a: np.all(a >= 1),
"> 0": lambda a: np.all(a > 0),
">0": lambda a: np.all(a > 0),
"> 1": lambda a: np.all(a > 1),
">1": lambda a: np.all(a > 1),
"<= 0": lambda a: np.all(a <= 0),
"<=0": lambda a: np.all(a <= 0),
"<= 1": lambda a: np.all(a <= 1),
"<=1": lambda a: np.all(a <= 1),
"< 0": lambda a: np.all(a < 0),
"<0": lambda a: np.all(a < 0),
"< 1": lambda a: np.all(a < 1),
"<1": lambda a: np.all(a < 1),
}


def _check_constraint(param_dt, constraint):
"""Check if param_dt satisfies the given constraint string.

Uses direct numpy comparison for known constraints,
falls back to eval for unknown constraint strings.

Parameters
----------
param_dt : np.array
constraint : str, e.g. '>= 0', '>=1'

Returns
-------
bool
"""
check = _CONSTRAINT_OPS.get(constraint)
if check is not None:
return check(param_dt)

# fallback for unknown constraints
return (eval("np." + repr(param_dt) + constraint)).all()


def check_param(param, param_name, dtype, constraint=None, iterable=True, max_depth=2):
"""
Checks the dtype of a parameter,
Expand All @@ -395,23 +438,26 @@ def check_param(param, param_name, dtype, constraint=None, iterable=True, max_de
-------
list of validated and converted parameter(s)
"""
msg = []
msg.append(param_name + " must be " + dtype)
if iterable:
msg.append(
" or nested iterable of depth "
+ str(max_depth)
+ " containing "
+ dtype
+ "s"
)

msg.append(", but found " + param_name + f" = {repr(param)}")
def _build_msg():
msg = []
msg.append(param_name + " must be " + dtype)
if iterable:
msg.append(
" or nested iterable of depth "
+ str(max_depth)
+ " containing "
+ dtype
+ "s"
)

msg.append(", but found " + param_name + f" = {repr(param)}")

if constraint is not None:
msg = (" " + constraint).join(msg)
else:
msg = "".join(msg)
if constraint is not None:
msg = (" " + constraint).join(msg)
else:
msg = "".join(msg)
return msg

# check param is numerical
try:
Expand All @@ -420,23 +466,23 @@ def check_param(param, param_name, dtype, constraint=None, iterable=True, max_de
) # + np.zeros_like(flatten(param), dtype='int')
# param_dt = np.array(param).astype(dtype)
except (ValueError, TypeError):
raise TypeError(msg)
raise TypeError(_build_msg())

# check iterable
if iterable:
if check_iterable_depth(param) > max_depth:
raise TypeError(msg)
raise TypeError(_build_msg())
if (not iterable) and isiterable(param):
raise TypeError(msg)
raise TypeError(_build_msg())

# check param is correct dtype
if not (param_dt == np.array(flatten(param)).astype(float)).all():
raise TypeError(msg)
raise TypeError(_build_msg())

# check constraint
if constraint is not None:
if not (eval("np." + repr(param_dt) + constraint)).all():
raise ValueError(msg)
if not _check_constraint(param_dt, constraint):
raise ValueError(_build_msg())

return param

Expand Down
Loading