Calculating the matrix of pairwise coalescence rates efficiently (with genetic_relatedness_vector) #3444
`TreeSequence.pair_coalescence_rates` is designed to compute rates for a small number of sample sets across a large number of time windows. It is inefficient when one wants rates for all pairs of samples in only a few time windows.
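In the latter case, a good option is to use `genetic_relatedness_vector` restricted to time windows. This works because of how branch divergence relates to coalescence. Take a pair of haplotypes and a time t, and compute their branch divergence on the tree sequence truncated at t (this divergence can be obtained from their branch relatedness values). The proportion of the sequence on which the pair is still uncoalesced at t is half the rate of change of that divergence with respect to t. For now, the time windowing has to be emulated by calling `decapitate` once per time point. Once time windowing is implemented directly in `genetic_relatedness_vector`, I suspect this approach will be close to optimal in terms of efficiency.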
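A sketch of the reasoning behind this, assuming all samples are at time 0 and span-normalised branch-mode statistics. Let $L$ be the sequence length and $T_{ij}(x)$ the time at which samples $i$ and $j$ coalesce at position $x$. Let $d_{ij}(t)$ be their branch divergence on the tree sequence truncated at time $t$ (as with `ts.decapitate(t)`), and $S_{ij}(t)$ the proportion of the sequence on which they are still uncoalesced at $t$. Finally, $G(t)$ is the uncentred branch relatedness matrix, assumed (as in the code below) to give divergence via $G_{ii} + G_{jj} - 2G_{ij}$. On the truncated tree sequence, the path between $i$ and $j$ at position $x$ has length $2\min(t, T_{ij}(x))$, so:

```math
\begin{aligned}
d_{ij}(t) &= \frac{1}{L}\int_0^L 2\min\bigl(t,\,T_{ij}(x)\bigr)\,dx, \\
\frac{\partial d_{ij}(t)}{\partial t} &= \frac{2}{L}\int_0^L \mathbf{1}\bigl[T_{ij}(x) > t\bigr]\,dx = 2\,S_{ij}(t), \\
d_{ij}(t) &= G_{ii}(t) + G_{jj}(t) - 2\,G_{ij}(t), \\
S_{ij}(t) &= \frac{d_{ij}(b) - d_{ij}(a)}{2\,(b - a)}, \qquad a \le t < b .
\end{aligned}
```

The last line holds when $a < b$ are consecutive node times. $d_{ij}$ is piecewise linear with kinks only at coalescence (node) times, so its slope is constant on $[a, b)$ and equals the finite difference.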
Then the pairwise coalescence-time CDF can be converted into rates with the same Kaplan-Meier estimator that `pair_coalescence_rates` uses.
Here's a mock implementation that computes the CDF this way:
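```python
import numpy as np
import tskit


def pairwise_coalescence_cdf(
    ts: tskit.TreeSequence,
    time_grid: np.ndarray,
    samples: np.ndarray = None,
) -> np.ndarray:
    """
    Return the proportion of sequence that has coalesced between all pairs of `samples`
    by the times in `time_grid`. If `samples` is None, then all samples in `ts`
    are used. The output array has dimensions `(time_grid.size, samples.size, samples.size)`.
    This uses `ts.decapitate(t).genetic_relatedness_vector(..., nodes=samples)`
    under the hood.
    """
    time_grid = np.asarray(time_grid, dtype=float)
    assert time_grid[0] == 0
    assert time_grid[-1] < np.inf
    assert np.all(np.diff(time_grid) > 0)
    if samples is None:
        samples = ts.samples()
    samples = np.asarray(samples)
    assert np.all(ts.nodes_time[samples] == 0)
    # Identity weights (one column per sample), so the matrix-vector product
    # returns the full pairwise relatedness matrix
    eye = np.eye(samples.size)
    # Distinct node times, padded with infinity so every grid point has a time above it
    times = np.append(np.unique(ts.nodes_time), np.inf)
    # For each grid point t, find consecutive node times a <= t < b;
    # the number of lineages (and hence the survival) is constant on [a, b)
    break_above = np.searchsorted(times, time_grid, side="right")
    assert break_above.min() > 0
    times_below = times[break_above - 1]
    times_above = times[break_above]
    pairwise_surv = np.zeros((time_grid.size, samples.size, samples.size))
    kwargs = {"nodes": samples, "mode": "branch", "centre": False}
    for i, (a, b) in enumerate(zip(times_below, times_above)):
        if a == 0:
            # No internal node at or below t, so nothing has coalesced yet
            # (previously only i == 0 was handled, leaving these rows at 0)
            pairwise_surv[i] = 1.0
        elif np.isfinite(b):
            # Slope of branch relatedness on [a, b); the factor 1/2 accounts for
            # both lineages of a pair contributing branch length to their divergence
            D = (
                ts.decapitate(b).genetic_relatedness_vector(eye, **kwargs)
                - ts.decapitate(a).genetic_relatedness_vector(eye, **kwargs)
            ) / (b - a) / 2
            # Relatedness slope -> divergence slope: D_ii + D_jj - 2 D_ij
            pairwise_surv[i] = np.add.outer(np.diag(D), np.diag(D)) - 2 * D
        # Otherwise t is at or above the oldest node; assuming every tree has a
        # single root, all pairs have coalesced and the survival stays 0
    return 1 - pairwise_surv
```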
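A minimal sketch of the rate conversion mentioned above. It assumes the rate is piecewise constant within each interval of `time_grid`, so the estimate for interval $k$ is $-\log\bigl(S(t_{k+1})/S(t_k)\bigr)/(t_{k+1}-t_k)$ with $S = 1 - \mathrm{CDF}$. The name `cdf_to_rates` is hypothetical and this is not taken from tskit's source. Check the `pair_coalescence_rates` documentation for how it treats empty or unbounded windows before relying on the two agreeing.

```python
import numpy as np


def cdf_to_rates(cdf: np.ndarray, time_grid: np.ndarray) -> np.ndarray:
    """
    Hypothetical helper: convert a pairwise coalescence-time CDF of shape
    (T, n, n), evaluated at the T increasing points of `time_grid`, into
    piecewise-constant coalescence rates of shape (T - 1, n, n), one per
    interval [time_grid[k], time_grid[k + 1]).
    """
    time_grid = np.asarray(time_grid, dtype=float)
    surv = 1.0 - np.asarray(cdf, dtype=float)
    widths = np.diff(time_grid)[:, None, None]
    with np.errstate(divide="ignore", invalid="ignore"):
        # Probability of surviving each interval, given survival to its start
        ratio = surv[1:] / surv[:-1]
        # Constant hazard that reproduces that conditional survival
        rates = -np.log(ratio) / widths
    # Pairs already fully coalesced at the start of an interval have no defined rate
    rates[~(surv[:-1] > 0)] = np.nan
    return rates
```

For example, `cdf_to_rates(pairwise_coalescence_cdf(ts, time_grid), time_grid)` would give one matrix of pairwise rates per interval of `time_grid`. The diagonal entries (a sample paired with itself) carry no information and should be ignored.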