Source code for bsb.simulation.targetting

import functools
import math
import typing

import numpy as np
from numpy.random import default_rng

from .. import config
from ..config import refs, types

if typing.TYPE_CHECKING:
    from ..cell_types import CellType
    from .cell import CellModel


@config.dynamic(attr_name="strategy", default="all", auto_classmap=True)
class Targetting:
    """
    Base class of the targetting strategies. The concrete class is selected
    dynamically through the ``strategy`` attribute (``"all"`` by default).
    """

    type: typing.Literal["cell"] | typing.Literal["connection"] = config.attr(
        type=types.in_(["cell", "connection"]), default="cell"
    )

    def get_targets(self, adapter, simulation, simdata):
        """
        Return the collection of ``simdata`` that corresponds to ``self.type``.

        :param adapter: Unused by this implementation.
        :param simulation: Unused by this implementation.
        :param simdata: Object exposing ``populations`` and ``connections``.
        :returns: ``simdata.populations`` for type ``"cell"``,
          ``simdata.connections`` for type ``"connection"``, ``None`` otherwise.
        """
        kind = self.type
        if kind == "cell":
            return simdata.populations
        if kind == "connection":
            return simdata.connections
        return None
@config.node
class CellTargetting(Targetting, classmap_entry="all"):
    """
    Targetting strategy registered as ``"all"``: targets every population.
    """

    @config.property
    def type(self):
        # Cell targetting strategies always report the "cell" kind.
        return "cell"

    def get_targets(self, adapter, simulation, simdata):
        """
        Return all populations stored on ``simdata``.

        :param adapter: Unused by this implementation.
        :param simulation: Unused by this implementation.
        :param simdata: Object exposing a ``populations`` attribute.
        :returns: ``simdata.populations``, unmodified.
        """
        populations = simdata.populations
        return populations
@config.node
class ConnectionTargetting(Targetting, classmap_entry="all_connections"):
    """
    Targetting strategy registered as ``"all_connections"``: targets every
    connection.
    """

    @config.property
    def type(self):
        # Connection targetting strategies always report the "connection" kind.
        return "connection"

    def get_targets(self, adapter, simulation, simdata):
        """
        Return all connections stored on ``simdata``.

        :param adapter: Unused by this implementation.
        :param simulation: Unused by this implementation.
        :param simdata: Object exposing a ``connections`` attribute.
        :returns: ``simdata.connections``, unmodified.
        """
        connections = simdata.connections
        return connections
class CellModelFilter:
    """
    Mixin that restricts the targeted populations to a list of cell models.
    """

    cell_models: list["CellModel"] = config.reflist(
        refs.sim_cell_model_ref, required=False
    )

    def get_targets(self, adapter, simulation, simdata):
        """
        Select the populations whose model is listed in ``cell_models``.

        :param adapter: Unused by this implementation.
        :param simulation: Unused by this implementation.
        :param simdata: Object exposing a ``populations`` mapping of
          model -> population.
        :returns: New dict of model -> population. When ``cell_models`` is
          empty or unset, every population is kept.
        """
        selected = {}
        for model, pop in simdata.populations.items():
            # A falsy model list means "no restriction".
            if self.cell_models and model not in self.cell_models:
                continue
            selected[model] = pop
        return selected
class CellTypeFilter:
    """
    Mixin that restricts the targets to the placement sets of a list of cell
    types.
    """

    cell_types: list["CellType"] = config.reflist(refs.cell_type_ref, required=False)
    only_local: bool = config.attr(type=bool, default=True)

    def get_targets(self, adapter, simulation, simdata):
        """
        Collect the placement set of every selected cell type.

        :param adapter: Unused by this implementation.
        :param simulation: Object whose ``scaffold.cell_types`` mapping is
          iterated.
        :param simdata: Object whose ``chunks`` are used when ``only_local``
          is set.
        :returns: Dict of cell type name -> placement set. When ``cell_types``
          is empty or unset, every cell type is kept.
        """
        # With `only_local`, restrict the placement sets to the chunks of simdata.
        if self.only_local:
            chunks = simdata.chunks
        else:
            chunks = None

        placement_sets = {}
        for cell_name, cell_type in simulation.scaffold.cell_types.items():
            if self.cell_types and cell_type not in self.cell_types:
                continue
            placement_sets[cell_name] = cell_type.get_placement_set(chunks=chunks)
        return placement_sets
class FractionFilter:
    """
    Mixin that reduces each group of targets to a random subset, given either
    as an absolute ``count`` or as a ``fraction`` of the group.
    """

    count = config.attr(
        type=int, required=types.mut_excl("fraction", "count", required=False)
    )
    fraction = config.attr(
        type=types.fraction(),
        required=types.mut_excl("fraction", "count", required=False),
    )

    def satisfy_fractions(self, targets):
        """
        Apply the count/fraction subsampling to every entry of ``targets``.

        :param targets: Mapping of model -> indexable collection of targets.
        :returns: New dict with the same keys and subsampled values.
        """
        return {model: self._frac(data) for model, data in targets.items()}

    def _frac(self, data):
        """
        Randomly subsample ``data`` according to ``count`` or ``fraction``.

        :param data: Sized collection that supports boolean mask indexing.
        :returns: ``data`` itself when neither ``count`` nor ``fraction`` is
          set, otherwise ``data`` indexed by a random boolean mask.
        """
        take = None
        if self.count is not None:
            take = self.count
        if self.fraction is not None:
            take = math.floor(len(data) * self.fraction)
        if take is None:
            return data
        # Sampling without replacement cannot draw more elements than exist;
        # a `count` larger than the available targets selects all of them
        # instead of raising a ValueError.
        take = min(take, len(data))
        # Select `take` elements from data with a boolean mask (otherwise a sorted
        # integer mask would be required)
        idx = np.zeros(len(data), dtype=bool)
        idx[np.random.choice(len(data), take, replace=False)] = True
        return data[idx]

    @staticmethod
    def filter(f):
        """
        Decorator for ``get_targets``-like methods: passes the mapping returned
        by ``f`` through :meth:`satisfy_fractions` of the same instance.

        :param f: Method returning a mapping of model -> targets.
        :returns: Wrapped method returning the subsampled mapping.
        """

        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            return self.satisfy_fractions(f(self, *args, **kwargs))

        return wrapper
@config.node
class CellModelTargetting(
    CellModelFilter, FractionFilter, CellTargetting, classmap_entry="cell_model"
):
    """
    Targets the cells of the given cell models, optionally subsampled by
    count or fraction.
    """

    cell_models: list["CellModel"] = config.reflist(
        refs.sim_cell_model_ref, required=True
    )

    @FractionFilter.filter
    def get_targets(self, adapter, simulation, simdata):
        """
        Return the targets of the next class in the MRO; the decorator applies
        the count/fraction subsampling to the result.
        """
        targets = super().get_targets(adapter, simulation, simdata)
        return targets
@config.node
class RepresentativesTargetting(
    CellModelFilter, FractionFilter, CellTargetting, classmap_entry="representatives"
):
    """
    Targets ``n`` randomly chosen representatives of each selected cell model.
    """

    n: int = config.attr(type=int, default=1)

    @FractionFilter.filter
    def get_targets(self, adapter, simulation, simdata):
        """
        Draw up to ``n`` distinct random indices per selected population.

        :param adapter: Forwarded to the parent ``get_targets``.
        :param simulation: Forwarded to the parent ``get_targets``.
        :param simdata: Forwarded to the parent ``get_targets``.
        :returns: Dict of model -> array of indices drawn from
          ``range(len(pop))``. Note that the values are indices into the
          population, not elements of the population itself.
        """
        rng = default_rng()
        populations = super().get_targets(adapter, simulation, simdata)
        return {
            # Sampling without replacement cannot draw more items than the
            # population holds, so cap the sample size at the population size.
            model: rng.choice(len(pop), size=min(self.n, len(pop)), replace=False)
            # `.items()` is required to iterate (model, population) pairs;
            # iterating the dict directly would only yield its keys.
            for model, pop in populations.items()
        }
@config.node
class ByIdTargetting(FractionFilter, CellTargetting, classmap_entry="by_id"):
    """
    Targets the cells with the given identifiers, listed per cell model name.
    """

    ids: dict[str, list[int]] = config.attr(
        type=types.dict(type=types.list(type=int)), required=True
    )

    @FractionFilter.filter
    def get_targets(self, adapter, simulation, simdata):
        """
        Look up the configured ids in the non-empty populations of ``simdata``.

        :param adapter: Unused by this implementation.
        :param simulation: Unused by this implementation.
        :param simdata: Object exposing ``populations`` and ``placement``
          mappings keyed by model.
        :returns: Dict of model -> population indexed by the converted ids.
          Names in ``ids`` without a matching non-empty population are skipped.
        """
        # Index the models that have a non-empty population by their name.
        models_by_name = {}
        for model, pop in simdata.populations.items():
            if len(pop) > 0:
                models_by_name[model.name] = model

        targets = {}
        for model_name, ids in self.ids.items():
            model = models_by_name.get(model_name)
            if model is None:
                continue
            pop = simdata.populations[model]
            local_ids = simdata.placement[model].convert_to_local(ids)
            targets[model] = pop[local_ids]
        return targets
@config.node
class ByLabelTargetting(
    CellModelFilter, FractionFilter, CellTargetting, classmap_entry="by_label"
):
    """
    Targets the cells that carry the given labels.
    """

    labels: list[str] = config.attr(type=types.list(type=str), required=True)

    @FractionFilter.filter
    def get_targets(self, adapter, simulation, simdata):
        """
        Mask each selected population with the label mask of its placement.

        :param adapter: Forwarded to the parent ``get_targets``.
        :param simulation: Forwarded to the parent ``get_targets``.
        :param simdata: Object exposing ``populations`` and ``placement``
          mappings keyed by model.
        :returns: Dict of model -> population indexed by the result of
          ``get_label_mask(self.labels)``.
        """
        targets = {}
        for model in super().get_targets(adapter, simulation, simdata):
            pop = simdata.populations[model]
            label_mask = simdata.placement[model].get_label_mask(self.labels)
            targets[model] = pop[label_mask]
        return targets
@config.node
class CylindricalTargetting(
    CellModelFilter, FractionFilter, CellTargetting, classmap_entry="cylinder"
):
    """
    Targets all cells in a cylinder along specified axis.
    """

    origin: np.ndarray[float] = config.attr(type=types.ndarray(shape=(2,), dtype=float))
    """
    Coordinates of the base of the cylinder for each non main axis.
    """
    axis: typing.Literal["x"] | typing.Literal["y"] | typing.Literal["z"] = config.attr(
        type=types.in_(["x", "y", "z"]), default="y"
    )
    """
    Main axis of the cylinder.
    """
    radius: float = config.attr(type=float, required=True)
    """
    Radius of the cylinder.
    """

    @FractionFilter.filter
    def get_targets(self, adapter, simulation, simdata):
        """
        Target all or certain cells within a cylinder of specified radius.

        :param adapter: Forwarded to the parent ``get_targets``.
        :param simulation: Forwarded to the parent ``get_targets``.
        :param simdata: Object exposing ``populations`` and ``placement``
          mappings keyed by model.
        :returns: Dict of model -> population masked to the cells whose
          squared distance to ``origin``, measured over the two non-main
          axes, is smaller than ``radius ** 2``.
        """
        # Position columns of the two axes perpendicular to the main axis;
        # any axis other than "x" or "y" is treated as "z".
        axes = {"x": [1, 2], "y": [0, 2]}.get(self.axis, [0, 1])

        targets = {}
        for model in super().get_targets(adapter, simulation, simdata):
            pop = simdata.populations[model]
            offsets = simdata.placement[model].load_positions()[:, axes] - self.origin
            inside = np.sum(offsets**2, axis=1) < self.radius**2
            targets[model] = pop[inside]
        return targets
@config.node
class SphericalTargettingCellTypes(
    CellTypeFilter, FractionFilter, Targetting, classmap_entry="sphere_cell_types"
):
    """
    Targets the cells of the selected cell types that lie within a sphere.
    """

    origin: list[float] = config.attr(type=types.list(type=float, size=3), required=True)
    radius: float = config.attr(type=float, required=True)

    @FractionFilter.filter
    def get_targets(self, adapter, simulation, simdata):
        """
        Target all or certain cells within a sphere of specified radius.

        :param adapter: Forwarded to the parent ``get_targets``.
        :param simulation: Forwarded to the parent ``get_targets``.
        :param simdata: Forwarded to the parent ``get_targets``.
        :returns: Dict of cell type name -> ids (from ``load_ids``) of the
          cells whose squared distance to ``origin`` is smaller than
          ``radius ** 2``.
        """
        placement_sets = super().get_targets(adapter, simulation, simdata)
        targets = {}
        for model, ps in placement_sets.items():
            ids = ps.load_ids()
            squared_dist = np.sum((ps.load_positions() - self.origin) ** 2, axis=1)
            targets[model] = ids[squared_dist < self.radius**2]
        return targets
@config.node
class SphericalTargetting(
    CellModelFilter, FractionFilter, CellTargetting, classmap_entry="sphere"
):
    """
    Targets the cells of the selected cell models that lie within a sphere.
    """

    origin: list[float] = config.attr(type=types.list(type=float, size=3), required=True)
    radius: float = config.attr(type=float, required=True)

    @FractionFilter.filter
    def get_targets(self, adapter, simulation, simdata):
        """
        Target all or certain cells within a sphere of specified radius.

        :param adapter: Forwarded to the parent ``get_targets``.
        :param simulation: Forwarded to the parent ``get_targets``.
        :param simdata: Object exposing ``populations`` and ``placement``
          mappings keyed by model.
        :returns: Dict of model -> population masked to the cells whose
          squared distance to ``origin`` is smaller than ``radius ** 2``.
        """
        targets = {}
        for model in super().get_targets(adapter, simulation, simdata):
            pop = simdata.populations[model]
            offsets = simdata.placement[model].load_positions() - self.origin
            inside = np.sum(offsets**2, axis=1) < self.radius**2
            targets[model] = pop[inside]
        return targets
@config.dynamic(
    attr_name="strategy",
    default="everywhere",
    auto_classmap=True,
    classmap_entry="everywhere",
)
class LocationTargetting:
    """
    Base location targetting strategy, registered as ``"everywhere"``:
    targets every location of a cell.
    """

    def get_locations(self, cell):
        """
        Return all locations of ``cell``.

        :param cell: Object exposing a ``locations`` mapping.
        :returns: New list with every value of ``cell.locations``.
        """
        return list(cell.locations.values())
@config.node
class SomaTargetting(LocationTargetting, classmap_entry="soma"):
    """
    Location targetting strategy registered as ``"soma"``.
    """

    def get_locations(self, cell):
        """
        Return the single location stored under key ``(0, 0)``.

        :param cell: Object exposing a ``locations`` mapping.
        :returns: One-element list holding ``cell.locations[(0, 0)]``.
        """
        # NOTE(review): presumably (0, 0) addresses the soma - confirm against
        # the cell model's location keys.
        soma = cell.locations[(0, 0)]
        return [soma]
@config.node
class LabelTargetting(LocationTargetting, classmap_entry="label"):
    """
    Targets the locations whose section carries all of the given labels.
    """

    labels = config.list(required=True)

    def get_locations(self, cell):
        """
        Filter the locations of ``cell`` by section labels.

        :param cell: Object exposing a ``locations`` mapping.
        :returns: List of the locations for which every entry of ``labels``
          is contained in ``loc.section.labels``.
        """
        matching = []
        for loc in cell.locations.values():
            if all(label in loc.section.labels for label in self.labels):
                matching.append(loc)
        return matching
@config.node
class BranchLocTargetting(LabelTargetting, classmap_entry="branch"):
    """
    Targets at most one location per branch among the labelled locations:
    the first one whose arc interval contains ``x``.
    """

    x = config.attr(type=types.fraction(), default=0.5)

    def get_locations(self, cell):
        """
        Pick, per branch, the first labelled location that spans ``x``.

        :param cell: Object exposing a ``locations`` mapping.
        :returns: List of locations satisfying
          ``loc.arc(0) <= x`` and ``loc.arc(1) > x``, one per distinct
          ``loc._loc[0]``.
        """
        # NOTE(review): the upper bound is strict, so `x == 1` presumably never
        # matches a location whose arc(1) equals 1 - confirm this is intended.
        seen_branches = set()
        picked = []
        for loc in super().get_locations(cell):
            branch = loc._loc[0]
            if branch in seen_branches:
                continue
            if not loc.arc(0) <= self.x:
                continue
            if not loc.arc(1) > self.x:
                continue
            picked.append(loc)
            seen_branches.add(branch)
        return picked
# Public API of the module. `CellTypeFilter` and `SphericalTargettingCellTypes`
# are public classes defined in this module and are exported alongside their
# cell model based counterparts.
__all__ = [
    "BranchLocTargetting",
    "ByIdTargetting",
    "ByLabelTargetting",
    "CellModelFilter",
    "CellModelTargetting",
    "CellTargetting",
    "CellTypeFilter",
    "ConnectionTargetting",
    "CylindricalTargetting",
    "FractionFilter",
    "LabelTargetting",
    "LocationTargetting",
    "RepresentativesTargetting",
    "SomaTargetting",
    "SphericalTargetting",
    "SphericalTargettingCellTypes",
    "Targetting",
]