Source code for rtcog.preproc.preproc_steps

import multiprocessing as mp
import os.path as osp
import numpy as np
from sklearn.preprocessing import StandardScaler
from scipy.signal.windows import exponential

from rtcog.utils.log import get_logger
from rtcog.utils.exceptions import VolumeOverflowError
from rtcog.preproc.helpers.preproc_utils import gen_polort_regressors
from rtcog.preproc.helpers.preproc_utils import rt_smooth_vol, calculate_spc
from rtcog.preproc.helpers.preproc_utils import CircularBuffer
from rtcog.preproc.helpers.iglm import iGLM
from rtcog.preproc.helpers.kalman_filter import KalmanFilter
from rtcog.preproc.step_types import StepType
from rtcog.utils.fMRI import unmask_fMRI_img


# Module-level logger used by the preprocessing steps defined in this file.
log = get_logger()

class PreprocStep:
    """
    Base class for a preprocessing step in the real-time fMRI pipeline.

    Subclasses must implement the `_run(pipeline)` method, which is called on
    each TR and receives access to the pipeline's state. Optionally, subclasses
    may also implement `_start(pipeline)` to initialize state at the first TR,
    and `_save(pipeline)` to perform any custom saving logic at the end of the
    run.

    Attributes
    ----------
    save : bool
        Whether to save the output from this step to disk.
    suffix : str or None
        Filename suffix to use for saving a NIfTI file (if `save=True`).
    Nv : int
        Number of voxels.
    Nt : int
        Number of time points.
    data_out : np.ndarray or None
        Array of shape (Nv, Nt) storing the output for each TR when
        `save=True`; ``None`` otherwise.

    Class Attributes
    ----------------
    registry : dict
        Mapping of registered step names (e.g., "ema", "iglm") to class
        objects. Automatically populated via `__init_subclass__`.

    Methods
    -------
    start_step(pipeline):
        Run the step's optional `_start` hook.
    run(pipeline):
        Execute the step's logic on the current TR and, if saving is enabled,
        store the result in `data_out`.
    end_step(pipeline):
        Optional cleanup hook; a no-op in the base class.
    save_nifti(pipeline):
        Write the accumulated data to disk as a NIfTI file, then run the
        step's optional `_save` hook.
    get_class(name):
        Class method that retrieves a registered step class by name
        (case-insensitive).
    snapshot():
        Return `data_out` keyed by the step name, for testing purposes.

    Notes
    -----
    The class defines `__eq__` without `__hash__`, so instances are not
    hashable (they cannot be used as dict keys or set members).
    """

    # Holds all available step classes that can be instantiated later on in Pipeline.
    registry = {}

    def __init__(self, *, save=False, suffix=None, Nv, Nt, **kwargs):
        """
        Parameters
        ----------
        save : bool, optional
            Whether to accumulate this step's output for saving.
        suffix : str or None, optional
            Filename suffix appended to the pipeline output prefix.
        Nv : int
            Number of voxels.
        Nt : int
            Number of time points.
        **kwargs
            Accepted and ignored by the base class.
        """
        self.save = save
        self.suffix = suffix
        self.Nv = Nv
        self.Nt = Nt
        # Always define the attribute so that snapshot() and other readers do
        # not hit AttributeError on steps that are not being saved.
        self.data_out = np.zeros((self.Nv, self.Nt)) if self.save else None

    @property
    def name(self):
        """Lower-cased class name with 'Step' removed (e.g. 'ema')."""
        return self.__class__.__name__.replace('Step', '').lower()

    def __init_subclass__(cls, **kwargs):
        """Register every concrete subclass in `registry` under its short name."""
        super().__init_subclass__(**kwargs)

        # Skip abstract or helper base classes
        if cls.__name__ == "PreprocStep" or cls.__name__.startswith("_"):
            return

        # Strip "Step" from the end of class names
        name = cls.__name__
        if name.endswith("Step"):
            name = name[:-4]
        name = name.lower()

        cls.registry[name] = cls

    def __eq__(self, other):
        """Allows lookup of object in pipeline.steps list"""
        if isinstance(other, str):
            cls_name = self.__class__.__name__
            return other.lower() in (cls_name.lower(), cls_name.replace('Step', '').lower())
        elif isinstance(other, PreprocStep):
            # Two steps compare equal when they are of exactly the same class.
            return self.__class__ == other.__class__
        return False

    @classmethod
    def get_class(cls, name):
        """Return the registered step class for `name` (case-insensitive), or None."""
        return cls.registry.get(name.lower())

    def start_step(self, pipeline):
        """Run the step's `_start` hook with the given pipeline."""
        self._start(pipeline)

    def run(self, pipeline):
        """
        Execute the step on the current TR.

        Calls `_run(pipeline)`; when saving is enabled, the first column of
        the result is stored in `data_out` at column `pipeline.t`.

        Returns
        -------
        The value returned by `_run(pipeline)`.
        """
        output = self._run(pipeline)
        if self.save:
            self.data_out[:, pipeline.t] = output[:, 0]
        return output

    def end_step(self, pipeline):
        """Cleanup hook called after processing; does nothing by default."""
        pass

    def save_nifti(self, pipeline):
        """
        Write `data_out` to `<out_dir>/<out_prefix><suffix>` when both `save`
        and `suffix` are set, then call the step's `_save` hook.
        """
        if self.save and self.suffix:
            out_path = osp.join(pipeline.out_dir, pipeline.out_prefix + self.suffix)
            unmask_fMRI_img(self.data_out, pipeline.mask_img, out_path)
        self._save(pipeline)

    def _start(self, pipeline):
        """Optional subclass hook invoked by `start_step`."""
        pass

    def _run(self, pipeline):
        """Per-TR logic; must be implemented by subclasses."""
        raise NotImplementedError

    def _save(self, pipeline):
        """Optional subclass hook invoked at the end of `save_nifti`."""
        pass

    def snapshot(self):
        """Return data_out for testing purposes"""
        return {self.name: self.data_out}
class EMAStep(PreprocStep):
    """
    Exponential moving average (EMA) step.

    Maintains a running EMA of the incoming volumes and returns the current
    volume minus that average.

    Parameters
    ----------
    save : bool, optional
        Whether to accumulate this step's output for saving.
    suffix : str, optional
        Filename suffix used when saving.
    Nv : int
        Number of voxels.
    Nt : int
        Number of time points.
    alpha : float, optional
        Weight given to the previous filter state; the new volume receives
        weight ``1 - alpha``.
    """

    def __init__(self, *, save=False, suffix='.pp_EMA.nii', Nv, Nt, alpha=0.98):
        super().__init__(save=save, suffix=suffix, Nv=Nv, Nt=Nt)
        self.alpha = alpha
        self.filt = None  # Filter state as a column vector; set on the first processed volume.

    def _run(self, pipeline):
        """Return the EMA-subtracted current volume as a column vector."""
        data = pipeline.Data_FromAFNI[:, :pipeline.t + 1]
        current = data[:, -1][:, np.newaxis]

        if pipeline.n == 1:  # First step
            self.filt = current
            if data.shape[1] < 2:
                # No earlier volume exists to difference against (t == 0).
                # The filter is seeded with the current volume, so the
                # volume-minus-filter output is zero.
                return np.zeros_like(current)
            return current - data[:, -2][:, np.newaxis]

        return self._apply_filter(current)

    def _apply_filter(self, data):
        """Apply the EMA filter to new data."""
        # filt_new = alpha * filt_prev + (1 - alpha) * data
        A = np.array([[self.alpha, 1 - self.alpha]])
        prev_filt = self.filt
        self.filt = np.dot(A, np.hstack([prev_filt, data]).T).T
        return data - self.filt
class iGLMStep(PreprocStep):
    """
    Incremental generalized linear model.

    Regresses nuisance signals (polynomial regressors and, optionally, the
    six motion parameters) out of each incoming volume.

    Parameters
    ----------
    save : bool, optional
        Whether to accumulate output and regression coefficients for saving.
    suffix : str, optional
        Filename suffix used when saving the step output.
    Nv : int
        Number of voxels.
    Nt : int
        Number of time points.
    num_polorts : int, optional
        Number of polynomial regressors. A negative value disables them.
    iGLM_motion : bool, optional
        Whether to include the motion parameters as nuisance regressors.
    """

    def __init__(self, *, save=False, suffix='.pp_iGLM.nii', Nv, Nt, num_polorts=2, iGLM_motion=True):
        super().__init__(save=save, suffix=suffix, Nv=Nv, Nt=Nt)
        self.num_polorts = num_polorts
        self.iGLM_motion = iGLM_motion
        self.iglm = iGLM()

        # A negative num_polorts means "no polynomial regressors"; clamp it so
        # the labels and the regressor count stay consistent with each other.
        n_polorts = max(self.num_polorts, 0)
        self.nuisance_labels = ['Polort' + str(i) for i in range(n_polorts)]
        if self.iGLM_motion:
            self.nuisance_labels += ['roll', 'pitch', 'yaw', 'dS', 'dL', 'dP']
        self.iGLM_num_regressors = len(self.nuisance_labels)

        if self.num_polorts > -1:
            self.legendre_pols = gen_polort_regressors(self.num_polorts, self.Nt)
        else:
            self.legendre_pols = None

        # Per-TR regression coefficients, kept only when saving.
        if self.save:
            self.iGLM_Coeffs = np.zeros((self.Nv, self.iGLM_num_regressors, self.Nt))
        else:
            self.iGLM_Coeffs = None

    def _run(self, pipeline):
        """
        Regress the nuisance signals for the current TR out of
        `pipeline.processed_tr`.

        Raises
        ------
        VolumeOverflowError
            If building the nuisance vector for `pipeline.t` raises IndexError
            (e.g. `pipeline.t` is beyond the generated regressors).
        """
        try:
            if self.legendre_pols is None:
                polort_row = np.empty(0)  # no polynomial regressors configured
            else:
                polort_row = self.legendre_pols[pipeline.t, :]

            if self.iGLM_motion:
                this_t_nuisance = np.concatenate((polort_row, pipeline.motion))[:, np.newaxis]
            else:
                this_t_nuisance = polort_row[:, np.newaxis]
        except IndexError as err:
            raise VolumeOverflowError() from err

        iGLM_data_out, Bn = self.iglm.regress_vol(
            pipeline.n,
            pipeline.processed_tr,
            this_t_nuisance,
        )

        if self.save:
            self.iGLM_Coeffs[:, :, pipeline.t] = np.squeeze(Bn, axis=2)

        return iGLM_data_out

    def _save(self, pipeline):
        """Write one coefficient file per nuisance regressor when saving is enabled."""
        if self.save:
            for i, lab in enumerate(self.nuisance_labels):
                data = self.iGLM_Coeffs[:, i, :]
                unmask_fMRI_img(
                    data,
                    pipeline.mask_img,
                    osp.join(pipeline.out_dir, pipeline.out_prefix + '.pp_iGLM_' + lab + '.nii'),
                )
class KalmanStep(PreprocStep):
    """Kalman filter for low pass filtering, spike removal, and signal smoothing.

    The constructor creates a multiprocessing pool of `n_cores` workers and
    hands it to the `KalmanFilter`; `end_step` closes and joins that pool.
    """

    def __init__(self, *, save=False, suffix='.pp_Kalman_LPfilter.nii', Nv, Nt, n_cores=10):
        super().__init__(save=save, suffix=suffix, Nv=Nv, Nt=Nt)

        self.n_cores = n_cores
        self.pool = mp.Pool(processes=self.n_cores)
        self.kalman_filter = KalmanFilter(self.Nv, self.n_cores, self.pool)

        log.info(f'Initializing Kalman pool with {self.n_cores} processes ...')
        self.kalman_filter.initialize_pool()

    def _run(self, pipeline):
        """Update the filter with the raw volume for this TR, then filter `processed_tr`."""
        kf = self.kalman_filter
        kf.update_welford(pipeline.n, pipeline.Data_FromAFNI[:, pipeline.t])
        return kf.run_volume(pipeline.n, pipeline.processed_tr)

    def end_step(self, pipeline):
        """Close and join the worker pool, if it still exists, and drop the reference."""
        if not hasattr(self, 'pool'):
            return
        self.pool.close()
        self.pool.join()
        del self.pool
class SmoothStep(PreprocStep):
    """Smoothing with Gaussian filter.

    `fwhm` is stored on the instance and passed to `rt_smooth_vol` on every TR.
    """

    def __init__(self, *, save=False, suffix='.pp_Smooth.nii', Nv, Nt, fwhm=4):
        super().__init__(save=save, suffix=suffix, Nv=Nv, Nt=Nt)
        self.fwhm = fwhm

    def _run(self, pipeline):
        """Return `pipeline.processed_tr` smoothed within `pipeline.mask_img`."""
        smoothed = rt_smooth_vol(
            pipeline.processed_tr,
            pipeline.mask_img,
            fwhm=self.fwhm,
        )
        return smoothed
class SnormStep(PreprocStep):
    """Spatial normalization.

    Each TR is standardized with a freshly fitted `StandardScaler`
    (mean removal and scaling to unit variance).
    """

    def __init__(self, *, save=False, Nv, Nt, suffix='.pp_Zscore.nii'):
        super().__init__(save=save, suffix=suffix, Nv=Nv, Nt=Nt)

    def _run(self, pipeline):
        """Return the z-scored version of `pipeline.processed_tr`."""
        volume = pipeline.processed_tr
        scaler = StandardScaler(with_mean=True, with_std=True)
        scaler.fit(volume)
        return scaler.transform(volume)
class TnormStep(PreprocStep):
    """
    Temporal normalization.

    Collects the first `nvols_to_compute` raw volumes (smoothed with the
    pipeline's smoothing FWHM when a smoothing step is present), averages
    them into a per-voxel baseline, and from then on returns
    `calculate_spc` of each processed volume relative to that baseline.

    Parameters
    ----------
    save : bool, optional
        Whether to accumulate this step's output for saving.
    Nv : int
        Number of voxels.
    Nt : int
        Number of time points.
    suffix : str, optional
        Filename suffix used when saving.
    nvols_to_compute : int, optional
        Number of volumes averaged to form the baseline.
    """

    def __init__(self, *, save=False, Nv, Nt, suffix='.pp_Tnorm.nii', nvols_to_compute=50):
        super().__init__(save=save, suffix=suffix, Nv=Nv, Nt=Nt)
        self.nvols_to_compute = nvols_to_compute
        self.mean_removed = False
        self.fwhm = None
        self.orig_data = np.zeros((self.Nv, self.nvols_to_compute))
        self.baseline_signal = None

    def _start(self, pipeline):
        """Pick up the smoothing FWHM and whether an EMA step is in the pipeline."""
        smooth_step = next(
            (step for step in pipeline.steps if step.name == StepType.SMOOTH.value),
            None,
        )
        if smooth_step:
            self.fwhm = smooth_step.fwhm
        # TODO: determine if iGLM presence with num_polorts >= 1 should also count here
        self.mean_removed = StepType.EMA.value in pipeline.steps

    def _run(self, pipeline):
        """
        Return the normalized volume once the baseline exists; before that,
        accumulate baseline data and pass `pipeline.processed_tr` through
        unchanged.
        """
        n = pipeline.n
        t = pipeline.t

        # Calculate SPC after baseline is established
        if self.baseline_signal is not None:
            current_signal = pipeline.processed_tr[:, 0]
            return calculate_spc(current_signal, self.baseline_signal, self.mean_removed)

        # Collect orig data until nvols_to_compute is reached
        if 0 < n <= self.nvols_to_compute:
            # Keep the raw volume as a column vector so the same `[:, 0]`
            # indexing works whether or not it gets smoothed below.
            vol = pipeline.Data_FromAFNI[:, t][:, np.newaxis]
            if self.fwhm is not None:  # Smooth if necessary
                vol = rt_smooth_vol(vol, pipeline.mask_img, fwhm=self.fwhm)
            self.orig_data[:, n - 1] = vol[:, 0]

            # Establish baseline
            if n == self.nvols_to_compute:
                log.info(f"++ INFO: establishing baseline at {t=}")
                self.baseline_signal = np.mean(self.orig_data, axis=1)
                self.baseline_signal[self.baseline_signal == 0] = 1e-6  # Avoid divide by zero
                log.info(f'{self.baseline_signal.shape=}')

        # During discard volumes, and while the baseline is still being collected
        return pipeline.processed_tr
class WindowingStep(PreprocStep):
    """
    Weighted combination of the most recent volumes using an exponential window.

    Parameters
    ----------
    save : bool, optional
        Whether to accumulate this step's output for saving.
    suffix : str, optional
        Filename suffix used when saving.
    Nv : int
        Number of voxels.
    Nt : int
        Number of time points.
    win_length : int, optional
        Number of volumes in the window.
    tau : float, optional
        Decay parameter passed to `scipy.signal.windows.exponential`.
    """

    def __init__(self, *, save=False, suffix='.pp_Windowed.nii', Nv, Nt, win_length=4, tau=3):
        super().__init__(save=save, suffix=suffix, Nv=Nv, Nt=Nt)
        self.buffer = None  # Created lazily on the first TR, once the volume size is known.
        self.win_length = win_length
        self.tau = tau

        win = exponential(self.win_length, center=0, tau=self.tau, sym=False)
        self.win_weights = win[:, np.newaxis]  # column vector of length win_length

    # NOTE: windowing used to only be done once matching began, now starting it along with all other preproc steps.
    def _run(self, pipeline):
        """
        Push the current volume into the buffer and return the weighted window.

        When the buffer does not return a window (presumably until it has
        filled - confirm against `CircularBuffer.update`), the current volume
        is returned unchanged.
        """
        if self.buffer is None:
            self.buffer = CircularBuffer(pipeline.processed_tr.shape[0], self.win_length)

        current_window = self.buffer.update(pipeline.processed_tr)
        if current_window is not None:
            return np.dot(current_window, self.win_weights)
        return pipeline.processed_tr