import numpy as np
from ._core import SortedSegment
from ._core import segment_and_sort_spikes, find_optimal_assignment
[docs]
class Recording:
    """
    Represents an extracellular recording stored in a flat binary file.

    This class manages metadata and provides efficient access to raw voltage data
    for spike sorting. It assumes a flat binary layout with channels interleaved
    sample-wise.

    Parameters
    ----------
    binary_file : str
        Path to the binary file containing the raw recording.
    data_type : dtype
        NumPy-compatible data type (e.g., ``float32``, ``int16``).
    probe_geometry : ndarray of shape (recording_channels, 2)
        Spatial coordinates (in micrometers) of each recording channel.
    sampling_frequency : float
        Sampling rate in Hz. Must be finite and at least 1000.
    num_samples : int, optional
        Total number of samples to load. If omitted, the number is inferred
        from file size and header.
    header : int, optional (default=0)
        Number of bytes to skip at the beginning of the file.
    sample_offset : int, optional (default=0)
        Number of samples to skip after the header.
    recording_offset : int, optional (default=0)
        Logical start time in samples, used for aligning or merging segments.
        Does not affect how data are read.

    Warning
    -------
    - After creation, this object should be treated as read-only.
    - Binary layout must be flat and channel-interleaved sample-wise.
    - The order of channels in `probe_geometry` must match the binary file.
    """

    VALID_DATA_TYPES = ('int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'int64', 'uint64', 'float32', 'float64')

    def __init__(self, *, binary_file, data_type, probe_geometry, sampling_frequency, num_samples=None, header=0, sample_offset=0, recording_offset=0):
        self.binary_file = binary_file
        self.data_type = data_type
        self.probe_geometry = probe_geometry
        self.sampling_frequency = sampling_frequency
        self.num_samples = num_samples
        self.header = header
        self.sample_offset = sample_offset
        self.recording_offset = recording_offset
        self.validate()

    def validate(self, *, verbose=False):
        """
        Finalizes setup and verifies recording consistency.

        Besides checking the constructor arguments, this method sets the
        derived attributes ``file_size``, ``recording_channels``,
        ``samples_per_ms`` and ``bytes_per_sample``, normalizes ``data_type``
        to a ``numpy.dtype`` and ``probe_geometry`` to ``float32``, and fills
        in ``num_samples`` when it was not given.

        Parameters
        ----------
        verbose : bool, optional
            If True, prints a summary of the recording.

        Raises
        ------
        ValueError
            If any of the file, geometry, or offset parameters are invalid.

        Warning
        -------
        - This method is called automatically during spike sorting.
        - Manual calls are typically only necessary for debugging or inspection.
        """
        # Validate file accessibility. Opening the file proves it is readable,
        # seeking to the end yields its size, and the ``with`` block releases
        # the handle immediately (no memory mapping is kept around).
        try:
            with open(self.binary_file, 'rb') as handle:
                self.file_size = handle.seek(0, 2)
        except Exception as e:
            raise ValueError(
                "Unable to read binary file "
                f"'{self.binary_file}': {e}"
            ) from e

        # Validate and resolve data type
        try:
            self.data_type = np.dtype(self.data_type)
        except Exception as e:
            raise ValueError(
                "Unable to resolve data type "
                f"'{self.data_type}': {e}"
            ) from e
        if self.data_type not in (np.dtype(x) for x in self.VALID_DATA_TYPES):
            raise ValueError(
                f"Invalid data type '{self.data_type}'. "
                f"Must be one of {self.VALID_DATA_TYPES}."
            )

        # Validate probe geometry
        if not isinstance(self.probe_geometry, np.ndarray) or self.probe_geometry.ndim != 2 or self.probe_geometry.shape[1] != 2:
            raise ValueError("probe_geometry must be a numpy array of shape (recording_channels, 2).")
        if self.probe_geometry.dtype.kind not in {'f', 'i', 'u'}:
            raise ValueError("probe_geometry must contain numeric values (float or int).")
        self.probe_geometry = self.probe_geometry.astype('float32')
        self.recording_channels = self.probe_geometry.shape[0]
        if self.recording_channels < 1:
            raise ValueError("No recording channels provided")

        # Validate sampling rate. The chained comparison is False for NaN and
        # for +inf, so non-finite rates are rejected here instead of failing
        # later inside int() with an unrelated message.
        if (
            not isinstance(self.sampling_frequency, (int, float, np.number))
            or not (1000 <= self.sampling_frequency < float("inf"))
        ):
            raise ValueError("sampling_frequency must be at least 1000 Hz.")
        self.samples_per_ms = int(self.sampling_frequency // 1000)

        # Validate offsets
        if not isinstance(self.header, (int, np.integer)) or self.header < 0:
            raise ValueError("header must be a non-negative integer.")
        if not isinstance(self.sample_offset, (int, np.integer)) or self.sample_offset < 0:
            raise ValueError("sample_offset must be a non-negative integer.")
        if not isinstance(self.recording_offset, (int, np.integer)):
            raise ValueError("recording_offset must be an integer.")

        # Check total offset. Python ints are used for the arithmetic so that
        # fixed-width NumPy integers cannot overflow on very large files.
        self.bytes_per_sample = self.recording_channels * self.data_type.itemsize
        data_start = int(self.header) + int(self.sample_offset) * self.bytes_per_sample
        if data_start >= self.file_size:
            raise ValueError("Offset exceeds file size; no valid data available.")
        available_bytes = self.file_size - data_start

        # Validate num_samples or infer it
        if self.num_samples is None:
            if available_bytes % self.bytes_per_sample != 0:
                raise ValueError(
                    "File is not aligned to full samples "
                    "(specify num_samples if you want to proceed)."
                )
            self.num_samples = available_bytes // self.bytes_per_sample
        else:
            if not isinstance(self.num_samples, (int, np.integer)) or self.num_samples <= 0:
                raise ValueError(f"num_samples must be a positive integer, but got {self.num_samples}.")
            max_possible_samples = available_bytes // self.bytes_per_sample
            if self.num_samples > max_possible_samples:
                raise ValueError(
                    f"Requested num_samples ({self.num_samples}) "
                    f"exceeds available samples ({max_possible_samples}) "
                    "based on file size."
                )

        if verbose:
            print(
                "=" * 60 + "\n"
                " SpikeSift - Recording Info\n"
                + "-" * 60 + "\n"
                f" File path : {self.binary_file}\n"
                f" Data type : {self.data_type}\n"
                f" Header offset : {self.header} bytes\n"
                f" Sample offset : {self.sample_offset} samples\n"
                f" Recording offset : {self.recording_offset} samples\n"
                f" Total duration : {self.num_samples / self.sampling_frequency} seconds\n"
                f" Number of channels : {self.recording_channels}\n"
                f" Sampling frequency : {self.sampling_frequency} Hz\n"
                + "=" * 60 + "\n"
            )

    def read(self, *, start, num_samples):
        """
        Reads a segment of the binary recording.

        Parameters
        ----------
        start : int
            Sample index to begin reading, after accounting for ``header`` and ``sample_offset``.
        num_samples : int
            Number of consecutive samples to read.

        Returns
        -------
        ndarray, shape ``(num_samples, recording_channels)``
            Extracted signal data as a NumPy array.

        Raises
        ------
        ValueError
            If ``start`` or ``num_samples`` is not a non-negative integer, or
            if the file ends before the requested samples could be read.

        Warning
        -------
        - This method is intended for debugging and manual inspection only.
        - SpikeSift handles all necessary data access internally during sorting.
        """
        for label, value in (("start", start), ("num_samples", num_samples)):
            if not isinstance(value, (int, np.integer)) or value < 0:
                raise ValueError(f"{label} must be a non-negative integer, but got {value}.")

        # Plain Python ints keep the byte arithmetic free of fixed-width overflow.
        first_sample = int(self.sample_offset) + int(start)
        byte_offset = int(self.header) + self.bytes_per_sample * first_sample
        num_values = int(num_samples) * self.recording_channels

        values = np.fromfile(
            self.binary_file,
            dtype=self.data_type,
            offset=byte_offset,
            count=num_values
        )
        # np.fromfile silently returns fewer items when the file ends early;
        # report that explicitly rather than through a reshape failure.
        if values.size != num_values:
            raise ValueError(
                f"Requested {num_samples} samples starting at sample {start}, "
                f"but only {values.size // self.recording_channels} could be read from the file."
            )
        return values.reshape(int(num_samples), self.recording_channels)
[docs]
class SortedRecording:
    """
    A sorted recording made of one or more sorted segments linked by cluster assignments.

    Cluster IDs are indices into the first segment; each entry of
    ``assignment_chain`` translates an index in one segment into the index
    used by the following segment.

    Parameters
    ----------
    sorted_segments : list of SortedSegment (internal)
        The sorted segments, in order. Each provides ``start_time``,
        ``duration``, ``spike_clusters`` and ``amplitude_vectors``.
    assignment_chain : list of ndarray of shape ``(num_clusters,)``
        One mapping per pair of adjacent segments.
        - Entry ``k`` of a mapping is the index of cluster ``k`` in the next segment.
        - ``-1`` marks a cluster with no counterpart.
    probe_geometry : ndarray of shape ``(recording_channels, 2)``
        2D electrode layout used for drift compensation.

    Warning
    -------
    - ``sorted_segments``, ``assignment_chain`` and ``probe_geometry`` are shared
      between recordings; treat them as immutable and never modify them in place.
    """

    def __init__(self, *, sorted_segments, assignment_chain, probe_geometry):
        self.sorted_segments = sorted_segments
        self.assignment_chain = assignment_chain
        self.probe_geometry = probe_geometry

    def start_time(self):
        """
        Returns the start time of the first segment (in samples).

        Returns
        -------
        int
            Start time in samples.
        """
        opening_segment = self.sorted_segments[0]
        return opening_segment.start_time

    def end_time(self):
        """
        Returns the end time of the last segment (in samples).

        Returns
        -------
        int
            ``start_time + duration`` of the last segment.
        """
        closing_segment = self.sorted_segments[-1]
        return closing_segment.start_time + closing_segment.duration

    def segment_boundaries(self):
        """
        Returns the sample range covered by every segment.

        Returns
        -------
        list of tuple
            One ``(start_sample, end_sample)`` pair per segment, in segment order.
        """
        boundaries = []
        for segment in self.sorted_segments:
            begin = segment.start_time
            boundaries.append((begin, segment.start_time + segment.duration))
        return boundaries

    def valid_cluster_id(self, cluster_id):
        """
        Tells whether a cluster ID can be followed through every segment.

        Parameters
        ----------
        cluster_id : int
            Candidate cluster ID (an index into the first segment).

        Returns
        -------
        bool
            True if the ID is an in-range integer and no step of the
            assignment chain maps it to ``-1``; False otherwise.
        """
        if not isinstance(cluster_id, (int, np.integer)):
            return False
        if not 0 <= cluster_id < len(self.sorted_segments[0]):
            return False

        current = cluster_id
        for mapping in self.assignment_chain:
            current = mapping[current]
            if current == -1:
                # The cluster has no counterpart in the next segment.
                return False
        return True

    def cluster_ids(self):
        """
        Returns every cluster ID that is valid for this recording.

        Returns
        -------
        set of int
            IDs for which :meth:`valid_cluster_id` is True.

        Warning
        -------
        - The same ID may denote different units in different SortedRecording objects.
        - Use :func:`~spikesift.map_clusters` to relate clusters of two recordings.
        """
        candidates = range(len(self.sorted_segments[0]))
        return set(filter(self.valid_cluster_id, candidates))

    def _gather_per_segment(self, cluster_id, field):
        """
        Collects ``segment.<field>[index]`` for one cluster in every segment.

        The index starts as ``cluster_id`` in the first segment and is
        translated by each assignment before the next segment is read.
        Raises ValueError when ``cluster_id`` is not valid.
        """
        if not self.valid_cluster_id(cluster_id):
            raise ValueError(
                f"Cluster ID {cluster_id} is not valid for this SortedRecording."
            )

        current = cluster_id
        gathered = [getattr(self.sorted_segments[0], field)[current]]
        for position, mapping in enumerate(self.assignment_chain):
            current = mapping[current]
            following_segment = self.sorted_segments[position + 1]
            gathered.append(getattr(following_segment, field)[current])
        return gathered

    def spikes(self, cluster_id):
        """
        Returns the spike times of one cluster, concatenated over all segments.

        Parameters
        ----------
        cluster_id : int
            The cluster ID to retrieve.

        Returns
        -------
        ndarray
            Concatenation of the cluster's per-segment spike-time arrays.

        Raises
        ------
        ValueError
            If the cluster ID is not valid for this recording.

        Warning
        -------
        - Cluster IDs only have meaning within this SortedRecording instance.
        - Call `.cluster_ids()` first to obtain the IDs that can be looked up.
        """
        per_segment_times = self._gather_per_segment(cluster_id, "spike_clusters")
        return np.concatenate(per_segment_times)

    def amplitude_vectors(self, cluster_id):
        """
        Returns the amplitude vector of one cluster in every segment.

        Parameters
        ----------
        cluster_id : int
            ID of the cluster.

        Returns
        -------
        ndarray of shape (num_segments, recording_channels)
            The per-segment amplitude vectors stacked along a new first axis.

        Raises
        ------
        ValueError
            If the cluster ID is not valid for this recording.

        Warning
        -------
        - Values reflect both spike-related activity and background
          fluctuations, and may be nonzero even on channels where the neuron is inactive.
        """
        per_segment_vectors = self._gather_per_segment(cluster_id, "amplitude_vectors")
        return np.stack(per_segment_vectors)

    def all_spikes(self):
        """
        Returns the spike times of every valid cluster.

        Returns
        -------
        dict of int -> ndarray
            Maps each valid cluster ID to the result of :meth:`spikes`.
        """
        spikes_by_cluster = {}
        for cid in self.cluster_ids():
            spikes_by_cluster[cid] = self.spikes(cid)
        return spikes_by_cluster

    def split_into_segments(self):
        """
        Wraps each segment in its own single-segment SortedRecording.

        Returns
        -------
        list of SortedRecording
            One recording per segment, each with an empty assignment chain
            and the same ``probe_geometry`` object as this recording.
        """
        pieces = []
        for segment in self.sorted_segments:
            single = SortedRecording(
                sorted_segments=[segment],
                assignment_chain=[],
                probe_geometry=self.probe_geometry,
            )
            pieces.append(single)
        return pieces

    def __len__(self):
        """
        Returns the number of valid clusters in the recording.

        Returns
        -------
        int
            Size of the set returned by :meth:`cluster_ids`.
        """
        valid_ids = self.cluster_ids()
        return len(valid_ids)
[docs]
def merge_recordings(
    sorted_recordings,
    *,
    max_drift=30
):
    """
    Aligns and merges multiple independently sorted recordings into a unified result.

    The segments of all recordings are concatenated in order. Each recording's
    own assignment chain is kept, and one additional "bridge" assignment is
    computed between the last segment of a recording and the first segment of
    the next one.

    Parameters
    ----------
    sorted_recordings : list of SortedRecording
        List of independently sorted recordings to be merged. Each entry must:
        - Contain at least one valid segment
        - Use the same probe geometry
        - Be sorted in time and have non-overlapping segments
    max_drift : float, optional (default=30)
        Maximum vertical shift (in micrometers) allowed when aligning clusters across segments.
        - Must be finite and non-negative
        - Internally rounded to the nearest multiple of 5
        - Higher values allow alignment over larger displacements

    Returns
    -------
    SortedRecording
        A single merged recording containing all aligned spike clusters.

    Raises
    ------
    ValueError
        If ``max_drift`` is invalid, the input list is empty, contains invalid
        types or recordings without segments, includes inconsistent geometries,
        or includes overlapping segment time ranges.

    Warning
    -------
    - This function assumes all inputs were produced by SpikeSift and remain unmodified.
    """
    # Validate max_drift. The chained comparison is False for NaN and +inf,
    # which would otherwise escape as OverflowError / an unrelated ValueError
    # from round() and int() below.
    if not isinstance(max_drift, (int, float, np.number)) or not (0 <= max_drift < float("inf")):
        raise ValueError(f"`max_drift` must be a finite, non-negative float or int (got {max_drift}).")
    max_drift = int(round(max_drift / 5)) * 5

    # Validate input list
    if not isinstance(sorted_recordings, list) or len(sorted_recordings) == 0:
        raise ValueError("`sorted_recordings` must be a non-empty list of SortedRecording objects.")

    # Validate geometry
    def validate_geometry(arr, label):
        if not (
            isinstance(arr, np.ndarray)
            and arr.dtype == np.float32
            and arr.ndim == 2
            and arr.shape[1] == 2
            and arr.flags['C_CONTIGUOUS']
        ):
            raise ValueError(
                f"{label} must be a C-contiguous float32 array of shape (n, 2). "
                "This may indicate that one of the SortedRecording objects or their segments was modified manually. "
                "Note: segments are shared between recordings and should be treated as immutable."
            )

    for i, rec in enumerate(sorted_recordings):
        if not isinstance(rec, SortedRecording):
            raise ValueError(f"Item {i} in `sorted_recordings` is not a SortedRecording.")
        # start_time()/end_time() and the bridging step index into the segment
        # list, so an empty recording must be rejected up front.
        if len(rec.sorted_segments) == 0:
            raise ValueError(f"Recording {i} in `sorted_recordings` contains no segments.")
        validate_geometry(rec.probe_geometry, f"probe_geometry of recording {i}")

    probe_geometry = sorted_recordings[0].probe_geometry
    num_channels = probe_geometry.shape[0]

    # Compare every recording with its predecessor; ``i`` is the index of the later one.
    consecutive = zip(sorted_recordings, sorted_recordings[1:])
    for i, (previous, current) in enumerate(consecutive, start=1):
        if not np.array_equal(current.probe_geometry, probe_geometry):
            raise ValueError(f"probe_geometry of recording {i} does not match earlier recordings.")
        if previous.end_time() > current.start_time():
            raise ValueError(
                f"`sorted_recordings` must be ordered and non-overlapping in time: "
                f"recording {i - 1} ends at {previous.end_time()}, "
                f"which overlaps with recording {i} starting at {current.start_time()}."
            )

    # Validate amplitude vectors
    def validate_amplitude_vectors(vecs, label):
        if not (
            isinstance(vecs, np.ndarray)
            and vecs.dtype == np.float32
            and vecs.ndim == 2
            and vecs.shape[1] == num_channels
            and vecs.flags['C_CONTIGUOUS']
        ):
            raise ValueError(
                f"{label} must be a C-contiguous float32 array of shape (num_clusters, {num_channels}). "
                "This may indicate that one of the SortedRecording objects or their segments was modified manually. "
                "Note: segments are shared between recordings and should be treated as immutable."
            )

    # Flatten all segments
    sorted_segments = [
        segment for rec in sorted_recordings
        for segment in rec.sorted_segments
    ]

    # Merge assignment chains
    updated_assignment_chain = []
    last_index = len(sorted_recordings) - 1
    for i, rec in enumerate(sorted_recordings):
        updated_assignment_chain.extend(rec.assignment_chain)

        # Bridge between adjacent recordings
        if i < last_index:
            seg_A = rec.sorted_segments[-1]
            seg_B = sorted_recordings[i + 1].sorted_segments[0]
            validate_amplitude_vectors(
                seg_A.amplitude_vectors,
                f"Amplitude vectors of last segment in recording {i}"
            )
            validate_amplitude_vectors(
                seg_B.amplitude_vectors,
                f"Amplitude vectors of first segment in recording {i + 1}"
            )
            bridge_assignment = find_optimal_assignment(
                seg_A.amplitude_vectors,
                seg_B.amplitude_vectors,
                probe_geometry=probe_geometry,
                max_drift=max_drift
            )
            updated_assignment_chain.append(bridge_assignment)

    return SortedRecording(
        sorted_segments=sorted_segments,
        assignment_chain=updated_assignment_chain,
        probe_geometry=probe_geometry
    )
[docs]
def map_clusters(
    source,
    target,
    *,
    max_drift=30
):
    """
    Computes a one-to-one mapping from clusters in ``source`` to their counterparts in ``target``.

    The assignment is computed between the amplitude vectors of the first
    segment of ``source`` and the first segment of ``target``. A pair is
    reported only when both IDs are valid in their own recording.

    Parameters
    ----------
    source : SortedRecording
        First sorted recording to compare.
    target : SortedRecording
        Second sorted recording to compare.
    max_drift : float, optional (default=30)
        Maximum vertical displacement (in micrometers) used during alignment.
        - Must be finite and non-negative
        - Internally rounded to the nearest multiple of 5
        - Higher values permit alignment across larger drift magnitudes

    Returns
    -------
    dict of int -> int
        Mapping from cluster IDs in ``source`` to corresponding cluster IDs in ``target``.
        Only valid, unambiguous one-to-one matches are included.

    Raises
    ------
    ValueError
        If inputs are invalid or incompatible (e.g., mismatched geometry,
        a recording without segments, or a non-finite ``max_drift``).

    Warning
    -------
    - This function assumes that both ``source`` and ``target`` were generated using SpikeSift
      and have not been manually modified.
    """
    # Validate and round drift threshold. The chained comparison is False for
    # NaN and +inf, which would otherwise escape as OverflowError / an
    # unrelated ValueError from round() and int() below.
    if not isinstance(max_drift, (int, float, np.number)) or not (0 <= max_drift < float("inf")):
        raise ValueError(f"`max_drift` must be a finite, non-negative float or int (got {max_drift}).")
    max_drift = int(round(max_drift / 5)) * 5

    # Validate input types
    if not isinstance(source, SortedRecording):
        raise ValueError("`source` must be a SortedRecording instance.")
    if not isinstance(target, SortedRecording):
        raise ValueError("`target` must be a SortedRecording instance.")

    # The first segment of each recording is indexed below, so both must have one.
    if len(source.sorted_segments) == 0:
        raise ValueError("`source` contains no segments.")
    if len(target.sorted_segments) == 0:
        raise ValueError("`target` contains no segments.")

    def validate_geometry(arr, label):
        if not (
            isinstance(arr, np.ndarray)
            and arr.dtype == np.float32
            and arr.ndim == 2
            and arr.shape[1] == 2
            and arr.flags['C_CONTIGUOUS']
        ):
            raise ValueError(
                f"{label} must be a C-contiguous float32 array of shape (n, 2). "
                "This may indicate that one of the SortedRecording objects or their segments was modified manually. "
                "Note: segments are shared between recordings and should be treated as immutable."
            )

    validate_geometry(
        source.probe_geometry,
        "source.probe_geometry"
    )
    validate_geometry(
        target.probe_geometry,
        "target.probe_geometry"
    )
    if not np.array_equal(source.probe_geometry, target.probe_geometry):
        raise ValueError("`probe_geometry` must match exactly between source and target.")

    # Validate amplitude vectors
    probe_geometry = source.probe_geometry
    num_channels = probe_geometry.shape[0]
    source_amplitudes = source.sorted_segments[0].amplitude_vectors
    target_amplitudes = target.sorted_segments[0].amplitude_vectors

    def validate_amplitude_vectors(vecs, label):
        if not (
            isinstance(vecs, np.ndarray)
            and vecs.dtype == np.float32
            and vecs.ndim == 2
            and vecs.shape[1] == num_channels
            and vecs.flags['C_CONTIGUOUS']
        ):
            raise ValueError(
                f"{label} must be a C-contiguous float32 array of shape (num_clusters, {num_channels}). "
                "This may indicate that one of the SortedRecording objects or their segments was modified manually. "
                "Note: segments are shared between recordings and should be treated as immutable."
            )

    validate_amplitude_vectors(
        source_amplitudes,
        "source.amplitude_vectors"
    )
    validate_amplitude_vectors(
        target_amplitudes,
        "target.amplitude_vectors"
    )

    # Perform optimal assignment using amplitude-based alignment
    assignment = find_optimal_assignment(
        source_amplitudes,
        target_amplitudes,
        probe_geometry=probe_geometry,
        max_drift=max_drift
    )

    # Construct cluster ID mapping.
    # NOTE(review): entries of ``assignment`` are presumably target indices with
    # -1 for "unassigned"; -1 is never in ``target_ids`` so it is filtered out.
    source_ids = source.cluster_ids()
    target_ids = target.cluster_ids()
    cluster_map = {}
    for i, j in enumerate(assignment):
        if i in source_ids and j in target_ids:
            # Store plain Python ints, as documented, rather than NumPy scalars.
            cluster_map[i] = int(j)
    return cluster_map