Source code for rtcog.matching.matcher

import sys
import pickle
import numpy as np

from rtcog.matching.matching_opts import MatchingOpts
from rtcog.utils.log import get_logger
from rtcog.utils.shared_memory_manager import SharedMemoryManager
from rtcog.utils.sync import SyncEvents

log = get_logger()


# TODO: accept SyncEvents instead of individual mp events
[docs] class Matcher: """ Base class for matching processed TR data to given templates. This class provides the framework for comparing incoming fMRI volumes against predefined brain state templates to detect patterns of interest. Subclasses implement specific matching algorithms (e.g., SVR-based or mask-based). Attributes ---------- registry : dict Class-level registry mapping matcher names to their classes. match_start : int First volume index to start computing match scores. Nt : int Total number of time points in the experiment. scores : np.ndarray Array of match scores, shape (Ntemplates, Nt). Ntemplates : int Number of templates to match against. mp_end : multiprocessing.Event Event to signal experiment end. mp_new_tr : multiprocessing.Event Event set when a new TR is processed. mp_shm_ready : multiprocessing.Event Event indicating shared memory is ready. Methods ------- from_name(name) Factory method to instantiate a matcher by name. match(t, n, tr_data) Compute similarity scores for a TR and update shared memory. setup_shared_memory() Initialize shared memory for score storage. cleanup_shared_memory() Clean up shared memory resources. _match(tr_data) Abstract method for computing match scores (implemented by subclasses). """ registry = {} # Holds all available matching classes def __init__(self, match_opts: MatchingOpts, Nt: int, sync: SyncEvents, match_path: str): """ Initialize the Matcher. Parameters ---------- match_opts : MatchingOpts Configuration options for matching. Nt : int Total number of time points. sync : SyncEvents Synchronization events container. match_path : str Path to matching data (e.g., templates or model). """ self.match_start = match_opts.match_start # First volume to start computing match scores on self.Nt = Nt self.scores = None self.Ntemplates = None self.mp_end = sync.end self.mp_new_tr = sync.new_tr self.mp_shm_ready = sync.shm_ready def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # Skip abstract or helper base classes if cls.__name__ == "Matcher" or cls.__name__.startswith("_"): return # Strip "Matcher" from the end of class names name = cls.__name__ if name.endswith("Matcher"): name = name[:-7] name = name.lower() cls.registry[name] = cls
[docs] @classmethod def from_name(cls, name): if name not in cls.registry: raise ValueError(f'Unknown matching method: {name}') return cls.registry[name]
[docs] def match(self, t, n, tr_data): """ Compute similarity scores for a TR and update shared memory. Parameters ---------- t : int Time point index. n : int Processed volume index. tr_data : np.ndarray Processed TR data. Returns ------- np.ndarray Updated scores array. """ if self.scores is None: self.scores = np.zeros((self.Ntemplates, self.Nt)) this_t_scores = self._match(tr_data) if this_t_scores.ndim != 1: raise ValueError( f"{self.__class__.__name__}._match() must return 1D array; " f"got shape {this_t_scores.shape}" ) if this_t_scores.shape[0] != self.Ntemplates: raise ValueError( f"{self.__class__.__name__}._match() returned {this_t_scores.shape[0]} " f"scores, expected {self.Ntemplates}" ) self.scores[:, t] = this_t_scores self.shared_arr[:, t] = this_t_scores self.mp_new_tr.set() log.debug(f'[t={t},n={n}] Online - Matching - scores.shape {self.scores.shape}') return self.scores
[docs] def setup_shared_memory(self): """ Initialize shared memory for score storage. Creates a shared memory buffer to pass match scores to the data streaming process. """ if self.Ntemplates is None: raise RuntimeError("Ntemplates must be set before creating shared memory") base_arr = np.zeros((self.Ntemplates, self.Nt), dtype=np.float32) self.shm_manager = SharedMemoryManager("match_scores", create=True, size=base_arr.nbytes) self.shm = self.shm_manager.open() self.shared_arr = np.ndarray(base_arr.shape, dtype=base_arr.dtype, buffer=self.shm.buf)
[docs] def cleanup_shared_memory(self): """ Clean up shared memory resources. """ if hasattr(self, 'shm_manager'): self.shm_manager.cleanup() if hasattr(self, 'shm_manager'): self.shm_manager.cleanup()
[docs] def _match(self, tr_data): """ Abstract method for computing match scores. Parameters ---------- tr_data : np.ndarray Processed TR data. Returns ------- np.ndarray 1D array of match scores for each template. """ raise NotImplementedError
[docs] class SVRMatcher(Matcher): """ Match to templates using pretrained SVR model. This matcher uses a support vector regression model trained on previous data to detect activation patterns in incoming TRs. """ def __init__(self, match_opts, Nt, sync, match_path): super().__init__(match_opts, Nt, sync, match_path) if match_path is None: self.mp_end.set() raise ValueError('SVR Model not provided.') try: with open(match_path, "rb") as f: self.input = pickle.load(f) except Exception as e: self.mp_end.set() raise RuntimeError(f'Unable to open SVR model pickle file: {e}') self.Ntemplates = len(self.input.keys()) self.template_labels = list(self.input.keys()) log.info(f'List of templates to be tested: {self.template_labels}') self.setup_shared_memory() self.mp_shm_ready.set() def _match(self, tr_data): out = [] data = np.squeeze(tr_data) for label in self.template_labels: out.append(self.input[label].predict(data[:,np.newaxis].T)[0]) return np.array(out)
[docs] class MaskMatcher(Matcher): """Match to templates using pretrained Mask Method""" def __init__(self, match_opts, Nt, sync, match_path): super().__init__(match_opts, Nt, sync, match_path) if match_path is None: self.mp_end.set() raise ValueError('Template info for match method not provided.') try: self.input = np.load(match_path, allow_pickle=True) except Exception as e: self.mp_end.set() raise RuntimeError(f'Error loading mask method file: {e}') self.template_labels = list(self.input["labels"]) self.Ntemplates = len(self.template_labels) log.info(f'List of templates to be tested: {self.template_labels}') self.masked_templates = self.input["masked_templates"].item() self.masks = self.input["masks"].item() self.voxel_counts = self.input["voxel_counts"].item() self.setup_shared_memory() self.mp_shm_ready.set() def _match(self, tr_data): out = [] for name in self.template_labels: mask = self.masks[name] template = self.masked_templates[name] masked_data = np.squeeze(tr_data)[mask] out.append(np.dot(template, masked_data) / self.voxel_counts[name]) return np.array(out)