diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py
index a1af1de348..134ddaacd0 100644
--- a/src/spikeinterface/metrics/template/metrics.py
+++ b/src/spikeinterface/metrics/template/metrics.py
@@ -284,21 +284,27 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y"
     return template[:, sort_indices], channel_locations[sort_indices, :]
 
 
-def fit_velocity(peak_times, channel_dist):
+def fit_line_robust(x, y, eps=1e-12):
     """
-    Fit velocity from peak times and channel distances using robust Theilsen estimator.
+    Fit line using robust Theil-Sen estimator (median of pairwise slopes).
     """
-    # from scipy.stats import linregress
-    # slope, intercept, _, _, _ = linregress(peak_times, channel_dist)
+    import itertools
 
-    from sklearn.linear_model import TheilSenRegressor
+    # Calculate slope and bias using Theil-Sen estimator
+    slopes = []
+    for (x0, y0), (x1, y1) in itertools.combinations(zip(x, y), 2):
+        if np.abs(x1 - x0) > eps:
+            slopes.append((y1 - y0) / (x1 - x0))
+    if len(slopes) == 0:  # all x are identical
+        return np.nan, -np.inf
+    slope = np.median(slopes)
+    bias = np.median(y - slope * x)
 
-    theil = TheilSenRegressor()
-    theil.fit(peak_times.reshape(-1, 1), channel_dist)
-    slope = theil.coef_[0]
-    intercept = theil.intercept_
-    score = theil.score(peak_times.reshape(-1, 1), channel_dist)
-    return slope, intercept, score
+    # Calculate R2 score
+    y_pred = slope * x + bias
+    r2_score = 1 - ((y - y_pred) ** 2).sum() / (((y - y.mean()) ** 2).sum() + eps)
+
+    return slope, r2_score
 
 
 def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs):
@@ -354,8 +360,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
         channel_locations_above = channel_locations[channels_above]
         peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time
         distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above])
-        velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above)
-        if score < min_r2:
+        inv_velocity_above, score = fit_line_robust(distances_um_above, peak_times_ms_above)
+        if score > min_r2 and np.abs(inv_velocity_above) > 1e-9:
+            velocity_above = 1 / inv_velocity_above
+        else:
             velocity_above = np.nan
 
     # Compute velocity below
@@ -367,8 +375,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
         channel_locations_below = channel_locations[channels_below]
         peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time
         distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below])
-        velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below)
-        if score < min_r2:
+        inv_velocity_below, score = fit_line_robust(distances_um_below, peak_times_ms_below)
+        if score > min_r2 and np.abs(inv_velocity_below) > 1e-9:
+            velocity_below = 1 / inv_velocity_below
+        else:
             velocity_below = np.nan
 
     return velocity_above, velocity_below
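
For reviewers, a standalone sketch of the new fit (not part of the patch): it copies fit_line_robust as introduced above and runs it on made-up distance/peak-time data with one jittered sample, then applies the same score-and-slope gating as the patched get_velocity_fits. The min_r2 value of 0.7 and the example numbers are assumptions for illustration only.

# Illustrative sketch only: mirrors fit_line_robust from the patch on synthetic data.
import itertools
import numpy as np


def fit_line_robust(x, y, eps=1e-12):
    """Theil-Sen fit: median of pairwise slopes, median-based intercept, plus an R2 score."""
    slopes = []
    for (x0, y0), (x1, y1) in itertools.combinations(zip(x, y), 2):
        if np.abs(x1 - x0) > eps:
            slopes.append((y1 - y0) / (x1 - x0))
    if len(slopes) == 0:  # all x are identical
        return np.nan, -np.inf
    slope = np.median(slopes)
    bias = np.median(y - slope * x)
    y_pred = slope * x + bias
    r2_score = 1 - ((y - y_pred) ** 2).sum() / (((y - y.mean()) ** 2).sum() + eps)
    return slope, r2_score


# Synthetic example: channels 0-100 um from the extremum channel, peak times
# consistent with an inverse velocity of 0.002 ms/um, one sample jittered by 0.02 ms.
distances_um = np.array([0.0, 20.0, 40.0, 60.0, 80.0, 100.0])
peak_times_ms = 0.002 * distances_um
peak_times_ms[3] += 0.02

min_r2 = 0.7  # assumed threshold for this sketch
inv_velocity, score = fit_line_robust(distances_um, peak_times_ms)
if score > min_r2 and np.abs(inv_velocity) > 1e-9:
    velocity = 1 / inv_velocity  # um/ms
else:
    velocity = np.nan
print(inv_velocity, score, velocity)  # slope ~0.002 ms/um, R2 ~0.99, velocity 500 um/ms

Fitting time against distance means the slope is an inverse velocity, which is why the patch inverts it and guards against near-zero slopes; a poor fit or degenerate geometry now yields NaN rather than an unstable value.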