Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions src/spikeinterface/metrics/template/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,21 +284,27 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y"
return template[:, sort_indices], channel_locations[sort_indices, :]


def fit_velocity(peak_times, channel_dist):
def fit_line_robust(x, y, eps=1e-12):
"""
Fit velocity from peak times and channel distances using robust Theilsen estimator.
Fit line using robust Theil-Sen estimator (median of pairwise slopes).
"""
# from scipy.stats import linregress
# slope, intercept, _, _, _ = linregress(peak_times, channel_dist)
import itertools

from sklearn.linear_model import TheilSenRegressor
# Calculate slope and bias using Theil-Sen estimator
slopes = []
for (x0, y0), (x1, y1) in itertools.combinations(zip(x, y), 2):
if np.abs(x1 - x0) > eps:
slopes.append((y1 - y0) / (x1 - x0))
if len(slopes) == 0: # all x are identical
return np.nan, -np.inf
slope = np.median(slopes)
bias = np.median(y - slope * x)

theil = TheilSenRegressor()
theil.fit(peak_times.reshape(-1, 1), channel_dist)
slope = theil.coef_[0]
intercept = theil.intercept_
score = theil.score(peak_times.reshape(-1, 1), channel_dist)
return slope, intercept, score
# Calculate R2 score
y_pred = slope * x + bias
r2_score = 1 - ((y - y_pred) ** 2).sum() / (((y - y.mean()) ** 2).sum() + eps)

return slope, r2_score


def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs):
Expand Down Expand Up @@ -354,8 +360,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
channel_locations_above = channel_locations[channels_above]
peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time
distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above])
velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above)
if score < min_r2:
inv_velocity_above, score = fit_line_robust(distances_um_above, peak_times_ms_above)
if score > min_r2 and np.abs(inv_velocity_above) > 1e-9:
velocity_above = 1 / inv_velocity_above
else:
velocity_above = np.nan

# Compute velocity below
Expand All @@ -367,8 +375,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
channel_locations_below = channel_locations[channels_below]
peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time
distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below])
velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below)
if score < min_r2:
inv_velocity_below, score = fit_line_robust(distances_um_below, peak_times_ms_below)
if score > min_r2 and np.abs(inv_velocity_below) > 1e-9:
velocity_below = 1 / inv_velocity_below
else:
velocity_below = np.nan

return velocity_above, velocity_below
Expand Down