import plotly.graph_objects as go
from plotly.subplots import make_subplots
from .networks import all_depth_first_branches, get_branch_points, reduce_branch
import numpy as np, math, functools
from .morphologies import Compartment
from contextlib import contextmanager
import random, types
from .reporting import warn
class CellTrace:
    """
    A single recorded trace: the recorded ``data`` plus the ``meta`` describing it.
    """

    def __init__(self, meta, data):
        # No color is assigned at construction; it can be set on the instance later.
        self.meta, self.data, self.color = meta, data, None
class CellTraces:
    """
    All traces that belong to one cell, together with the cell's id, the title to
    display for it and an optional sort order.
    """

    def __init__(self, id, title, order=None):
        self.cell_id = id
        self.title = title
        self.order = order
        self.traces = []

    def add(self, meta, data):
        """
        Wrap ``meta`` and ``data`` in a :class:`CellTrace` and append it.
        """
        trace = CellTrace(meta, data)
        self.traces.append(trace)

    def __len__(self):
        return len(self.traces)

    def __iter__(self):
        return iter(self.traces)
class CellTraceCollection:
    """
    A mapping of cell id to :class:`CellTraces`, plus the legends and colors that
    are used when the collection is plotted.

    :param cells: Either a dictionary of cell id to cell traces, or a list of objects
      with a ``cell_id`` attribute from which such a dictionary is built.
    """

    def __init__(self, cells=None):
        if isinstance(cells, list):
            cells = {cell.cell_id: cell for cell in cells}
        elif cells is None:
            cells = {}
        self.cells = cells
        self.legends = []
        self.colors = []

    def set_legends(self, legends):
        self.legends = legends

    def set_colors(self, colors):
        self.colors = colors

    def add(self, id, meta, data):
        """
        Add a trace to the cell with the given id. The cell entry is created on first
        use; its title comes from ``meta["display_label"]`` and its order from
        ``meta["order"]`` when present.
        """
        if id not in self.cells:
            title = meta.get("display_label", "Cell " + str(id))
            self.cells[id] = CellTraces(id, title, order=meta.get("order", None))
        self.cells[id].add(meta, data)

    def __len__(self):
        return len(self.cells)

    def __iter__(self):
        return iter(self.cells.values())

    def order(self):
        """
        Rebuild the cell dictionary sorted by each cell's ``order``; a falsy order
        (such as ``None``) counts as 0.
        """

        def sort_key(item):
            return item[1].order or 0

        self.cells = dict(sorted(self.cells.items(), key=sort_key))

    def reorder(self, order):
        """
        Assign the values of ``order`` to the cells, in their current iteration order,
        and sort the collection again.
        """
        for position, key in zip(order, self.cells):
            self.cells[key].order = position
        self.order()
def _figure(f):
"""
Decorator for functions that produce a Figure. Can set defaults, create and show
figures and disable the legend.
Adds the `show` and `legend` keyword arguments.
"""
@functools.wraps(f)
def wrapper_function(*args, fig=None, show=True, legend=True, **kwargs):
if fig is None:
fig = go.Figure()
r = f(*args, fig=fig, show=show, legend=legend, **kwargs)
fig.update_layout(showlegend=legend)
if show:
fig.show()
return r
return wrapper_function
def _network_figure(f):
    """
    Decorator for functions that produce a Figure of a network. Applies ``@_figure``
    and can set a cubic aspect mode and swap the Y & Z axis labels.

    Adds the `cubic` and `swapaxes` keyword arguments; both are also forwarded to the
    decorated function.
    """

    @functools.wraps(f)
    @_figure
    def wrapper_function(*args, fig=None, cubic=True, swapaxes=True, **kwargs):
        result = f(*args, fig=fig, cubic=cubic, swapaxes=swapaxes, **kwargs)
        if cubic:
            fig.update_layout(scene_aspectmode="cube")
        # Only the axis titles are exchanged here, the data itself is left untouched.
        y_title, z_title = ("Z", "Y") if swapaxes else ("Y", "Z")
        fig.update_layout(
            scene=dict(xaxis_title="X", yaxis_title=y_title, zaxis_title=z_title)
        )
        return result

    return wrapper_function
def _morpho_figure(f):
    """
    Decorator for functions that produce a Figure of a morphology. Applies ``@_figure``,
    fills in a default offset, can set the scene range & aspect ratio from the
    morphology bounds and can swap the Y & Z axis labels.

    Adds the `offset`, `set_range`, `swapaxes` and `soma_radius` keyword arguments;
    all of them are also forwarded to the decorated function.
    """

    @functools.wraps(f)
    @_figure
    def wrapper_function(
        morphology,
        *args,
        offset=None,
        set_range=True,
        fig=None,
        swapaxes=True,
        soma_radius=None,
        **kwargs,
    ):
        offset = [0.0, 0.0, 0.0] if offset is None else offset
        result = f(
            morphology,
            *args,
            fig=fig,
            offset=offset,
            set_range=set_range,
            swapaxes=swapaxes,
            soma_radius=soma_radius,
            **kwargs,
        )
        if set_range:
            bounds = get_morphology_range(
                morphology, offset=offset, soma_radius=soma_radius
            )
            set_scene_range(fig.layout.scene, bounds)
            set_scene_aspect(fig.layout.scene, bounds)
        # Only the axis titles depend on `swapaxes` in this decorator.
        y_title, z_title = ("Z", "Y") if swapaxes else ("Y", "Z")
        fig.update_layout(
            scene=dict(xaxis_title="X", yaxis_title=y_title, zaxis_title=z_title)
        )
        return result

    return wrapper_function
def _input_highlight(f, required=False):
"""
Decorator for functions that highlight an input region on a Figure.
Adds the `input_region` keyword argument. Decorated function has to have a `fig`
keyword argument.
:param required: If set to True, an ArgumentError is thrown if no `input_region`
is specified
:type required: bool
"""
@functools.wraps(f)
def wrapper_function(*args, fig=None, input_region=None, **kwargs):
r = f(*args, fig=fig, **kwargs)
if input_region is not None:
shapes = [
dict(
type="rect",
xref="x",
yref="paper",
x0=input_region[0],
y0=0,
x1=input_region[1],
y1=1,
fillcolor="#d3d3d3",
opacity=0.3,
line=dict(width=0),
)
]
fig.update_layout(shapes=shapes)
elif required:
raise ArgumentError("Missing required keyword argument `input_region`.")
return r
return wrapper_function
def _plot_network(network, fig, cubic, swapaxes):
    """
    Add one 3D scatter trace per non-entity cell type to ``fig``.

    :param network: Network whose ``cells_by_type`` cache is plotted.
    :param fig: Figure to add the traces to.
    :param cubic: If True, give all 3 scene axes a range of the same length.
    :type cubic: bool
    :param swapaxes: If True, plot the 3rd position column on the Y axis and the 2nd
      on the Z axis.
    :type swapaxes: bool
    """
    # All bounds start at 0, so the origin is always part of the plotted range.
    xmin = xmax = ymin = ymax = zmin = zmax = 0
    for cell_type in network.configuration.cell_types.values():
        if cell_type.entity:
            continue
        # presumably columns 2, 3 & 4 hold the x, y & z position - verify against
        # the layout of `cells_by_type`.
        pos = network.cells_by_type[cell_type.name][:, [2, 3, 4]]
        fig.add_trace(
            go.Scatter3d(
                x=pos[:, 0],
                y=pos[:, 2 if swapaxes else 1],
                z=pos[:, 1 if swapaxes else 2],
                mode="markers",
                marker=dict(
                    color=cell_type.plotting.color, size=cell_type.placement.radius
                ),
                opacity=cell_type.plotting.opacity,
                name=cell_type.plotting.label,
            )
        )
        xmin = min(xmin, np.min(pos[:, 0], initial=0))
        xmax = max(xmax, np.max(pos[:, 0], initial=0))
        ymin = min(ymin, np.min(pos[:, 1], initial=0))
        ymax = max(ymax, np.max(pos[:, 1], initial=0))
        zmin = min(zmin, np.min(pos[:, 2], initial=0))
        zmax = max(zmax, np.max(pos[:, 2], initial=0))
    if cubic:
        rng = max(xmax - xmin, ymax - ymin, zmax - zmin)
        y_range = [ymin, ymin + rng]
        z_range = [zmin, zmin + rng]
        scene = fig.layout.scene
        scene.xaxis.range = [xmin, xmin + rng]
        # With swapped axes the scene's Y axis displays the Z data and vice versa,
        # so the ranges have to be swapped along with the data.
        if swapaxes:
            scene.yaxis.range = z_range
            scene.zaxis.range = y_range
        else:
            scene.yaxis.range = y_range
            scene.zaxis.range = z_range
@_network_figure
def plot_network(
    network, fig=None, cubic=True, swapaxes=True, show=True, legend=True, from_memory=True
):
    """
    Plot a network, either from the current cache or the storage.

    :param network: The network to plot.
    :param from_memory: If True, plot the cells that are currently cached on the
      network. If False, reset the cache and reload every non-entity cell type
      before plotting.
    :type from_memory: bool
    :returns: The figure the network was plotted on.
    """
    if not from_memory:
        network.reset_network_cache()
        for cell_type in network.configuration.cell_types.values():
            if cell_type.entity:
                continue
            # Load from HDF5
            network.get_cells_by_type(cell_type.name)
    _plot_network(network, fig, cubic, swapaxes)
    return fig
@_network_figure
def network_figure(fig=None, **kwargs):
    """
    Return a figure that only went through the ``@_network_figure`` decorator; no
    traces are added here.
    """
    return fig
@_network_figure
def plot_detailed_network(
    network, fig=None, cubic=True, swapaxes=True, show=True, legend=True, ids=None
):
    """
    Plot the morphology of every placed cell of a network on a single figure.

    Cell types without morphologies are skipped; cell types with more than one
    morphology raise a ``NotImplementedError``.

    :param network: The network to plot.
    :param ids: If given, only the cells whose id is in this collection are plotted.
    :returns: The figure the network was plotted on.
    """
    # Import kept from the original code; the imported name itself is not used here.
    from .output import MorphologyRepository

    scene_builder = MorphologyScene(fig)
    repository = network.morphology_repository
    for cell_type in network.configuration.cell_types.values():
        # Granule cells are drawn with thinner segments than all other cell types.
        segment_radius = 1.0 if cell_type.name == "granule_cell" else 2.5
        morphology_names = cell_type.list_all_morphologies()
        if len(morphology_names) == 0:
            continue
        if len(morphology_names) > 1:
            raise NotImplementedError(
                "We haven't implemented plotting different morphologies per cell type yet. Open an issue if you need it."
            )
        cells = network.get_placement_set(cell_type.name).cells
        morpho = repository.get_morphology(morphology_names[0])
        for cell in cells:
            if ids is not None and cell.id not in ids:
                continue
            scene_builder.add_morphology(
                morpho,
                cell.position,
                color=cell_type.plotting.color,
                soma_radius=cell_type.placement.soma_radius,
                segment_radius=segment_radius,
            )
    scene_builder.prepare_plot()
    # Fixed axis ranges, they replace the ranges set while preparing the plot.
    scene = fig.layout.scene
    scene.xaxis.range = [-200, 200]
    scene.yaxis.range = [-200, 200]
    scene.zaxis.range = [0, 600]
    return fig
def get_voxel_cloud_traces(cloud, selected_voxels=None, offset=None, color=None):
    """
    Create the block traces for each voxel of a voxel cloud.

    :param cloud: The voxel cloud; its ``get_boxes``, ``voxels``, ``get_occupancies``
      and ``grid_size`` members are used.
    :param selected_voxels: If given, these voxels are drawn opaque and all others
      fully transparent. Otherwise the alpha of each voxel is its scaled occupancy.
    :param offset: Offset added to each voxel position, defaults to the origin.
    :type offset: 3-element list-like
    :param color: RGB components of the voxel color, defaults to red.
    :type color: 3-element list-like
    :returns: A flat list with the edge trace and face trace of every voxel.
    :rtype: list
    """
    if offset is None:
        offset = [0.0, 0.0, 0.0]
    # Calculate the 3D voxel indices based on the voxel positions and the grid size.
    boxes = cloud.get_boxes()
    voxels = cloud.voxels.copy()
    box_positions = np.column_stack(boxes[:, voxels])
    # Scale the occupancy of each voxel to determine transparency
    occupancies = cloud.get_occupancies() / 1.5
    if color is None:
        color = [255, 0, 0]
    rgb = ",".join(str(c) for c in color)
    colors = np.empty(voxels.shape, dtype=object)
    if selected_voxels is not None:
        # Color selected voxels, hide all the others
        colors[voxels] = "rgba(0, 0, 0, 0.0)"
        colors[selected_voxels] = "rgba(" + rgb + ", 1.0)"
    else:
        # Prepare voxels with the compartment density coded into the alpha of the facecolor
        colors[voxels] = [f"rgba({rgb}, {o})" for o in occupancies]
    sizes = [cloud.grid_size, cloud.grid_size, cloud.grid_size]
    traces = []
    for box, voxel_color in zip(box_positions, colors[voxels]):
        # `box + offset` creates a new array, so integer positions can be combined
        # with a float offset without an in-place casting error.
        traces.extend(plotly_block(box + offset, sizes, voxel_color))
    return traces
@_network_figure
def plot_voxel_cloud(
    cloud,
    selected_voxels=None,
    fig=None,
    show=True,
    legend=True,
    cubic=True,
    swapaxes=True,
    set_range=True,
    color=None,
    offset=None,
):
    """
    Plot a voxel cloud.

    :param cloud: The voxel cloud to plot.
    :param selected_voxels: Voxels to highlight, see :func:`get_voxel_cloud_traces`.
    :param set_range: If True, set the range of all 3 scene axes to the smallest and
      largest value of the cloud's voxel box, shifted by the offset.
    :type set_range: bool
    :param color: RGB components of the voxel color.
    :param offset: Offset of the cloud, defaults to the origin.
    :type offset: 3-element list-like
    :returns: The figure the cloud was plotted on.
    """
    if offset is None:
        offset = [0.0, 0.0, 0.0]
    traces = get_voxel_cloud_traces(
        cloud, selected_voxels=selected_voxels, offset=offset, color=color
    )
    for trace in traces:
        fig.add_trace(trace)
    if set_range:
        box = cloud.get_voxel_box()
        low, high = min(box), max(box)
        # With swapped axes the scene's Y axis shows the Z data, so it needs the Z
        # offset (and the other way around).
        if swapaxes:
            y_shift, z_shift = offset[2], offset[1]
        else:
            y_shift, z_shift = offset[1], offset[2]
        scene = fig.layout.scene
        scene.xaxis.range = [low + offset[0], high + offset[0]]
        scene.yaxis.range = [low + y_shift, high + y_shift]
        scene.zaxis.range = [low + z_shift, high + z_shift]
    return fig
def get_branch_trace(compartments, offset=[0.0, 0.0, 0.0], color="black", width=1.0):
    """
    Create the 3D line trace of a branch.

    The line runs through the start of each compartment and ends at the end of the
    last compartment. With a ``width`` of 0 only that endpoint is included.

    :param compartments: The compartments of the branch, in order.
    :param offset: Offset added to every point.
    :param color: Line color.
    :param width: Line width.
    :returns: The line trace; note that Y and Z are swapped in the trace.
    """
    if width == 0:
        xs, ys, zs = [], [], []
    else:
        xs, ys, zs = (
            [c.start[axis] + offset[axis] for c in compartments] for axis in range(3)
        )
    # Add branch endpoint
    for axis, points in enumerate((xs, ys, zs)):
        points.append(compartments[-1].end[axis] + offset[axis])
    line = dict(width=width, color=color)
    return go.Scatter3d(x=xs, y=zs, z=ys, mode="lines", line=line, showlegend=False)
def get_soma_trace(
    soma_radius, offset=[0.0, 0.0, 0.0], color="black", opacity=1, steps=5, **kwargs
):
    """
    Create a sphere mesh trace for a soma.

    :param soma_radius: Radius of the sphere.
    :param offset: Center of the sphere; its 2nd and 3rd component are applied to the
      Z and Y coordinate of the trace respectively.
    :param color: Color of the mesh.
    :param opacity: Opacity of the mesh.
    :param steps: Amount of latitude samples; twice as many longitude samples are used.
    :param kwargs: Additional keyword arguments for the ``Mesh3d`` trace.
    """
    longitudes = np.linspace(0, 2 * np.pi, num=steps * 2)
    latitudes = np.linspace(-np.pi / 2, np.pi / 2, num=steps)
    phi, theta = np.meshgrid(longitudes, latitudes)
    x = np.cos(theta) * np.sin(phi) * soma_radius + offset[0]
    y = np.cos(theta) * np.cos(phi) * soma_radius + offset[2]
    z = np.sin(theta) * soma_radius + offset[1]
    coordinates = dict(x=x.flatten(), y=y.flatten(), z=z.flatten())
    return go.Mesh3d(
        **coordinates,
        opacity=opacity,
        color=color,
        alphahull=0,
        **kwargs,
    )
@_network_figure
def plot_fiber_morphology(
    fiber,
    offset=[0.0, 0.0, 0.0],
    fig=None,
    cubic=True,
    swapaxes=True,
    show=True,
    legend=True,
    set_range=True,
    color="black",
    segment_radius=1.0,
):
    """
    Plot a fiber morphology by adding a line trace for each of its branches.

    :param fiber: The fiber; its ``root_branches`` and their ``child_branches`` are
      walked depth-first.
    :param offset: Offset added to every point of the fiber.
    :param color: Line color of the branches.
    :param segment_radius: Line width of the branches.
    :returns: The figure the fiber was plotted on.
    """

    def walk(branches):
        # Yield the trace of a branch before the traces of its children.
        for branch in branches:
            yield get_branch_trace(
                branch._compartments, offset, color=color, width=segment_radius
            )
            yield from walk(branch.child_branches)

    # Build every trace before the first one is added to the figure.
    for trace in list(walk(fiber.root_branches)):
        fig.add_trace(trace)
    return fig
@_morpho_figure
def plot_morphology(
    morphology,
    offset=None,
    fig=None,
    swapaxes=True,
    show=True,
    legend=True,
    set_range=True,
    color="black",
    reduce_branches=False,
    soma_radius=None,
    soma_opacity=1.0,
    segment_radius=1.0,
    use_last_soma_comp=True,
):
    """
    Plot a morphology: a line trace per branch and a sphere for the soma.

    :param morphology: The morphology to plot.
    :param offset: Offset of the morphology.
    :param color: A single color, or a dictionary of compartment label to color. A
      dictionary has to contain a color for ``soma``.
    :param reduce_branches: If True, each branch is passed through ``reduce_branch``
      before it is plotted.
    :type reduce_branches: bool
    :param soma_radius: Radius of the soma sphere, defaults to the radius of the
      chosen soma compartment.
    :param soma_opacity: Opacity of the soma sphere.
    :param segment_radius: A single line width, or a dictionary of compartment label
      to line width.
    :param use_last_soma_comp: If True, the soma is centered on the end of the last
      soma compartment, otherwise on the start of the first.
    :type use_last_soma_comp: bool
    :returns: The figure the morphology was plotted on.
    """
    compartments = np.array(morphology.compartments.copy())
    dfs_list = all_depth_first_branches(morphology.get_compartment_network())
    if reduce_branches:
        branch_points = get_branch_points(dfs_list)
        dfs_list = [reduce_branch(b, branch_points) for b in dfs_list]
    traces = []
    for branch in dfs_list[::-1]:
        branch_comps = compartments[branch]
        branch_trace = get_branch_trace(
            branch_comps,
            offset,
            color=_get_branch_color(branch_comps, color),
            width=_get_branch_width(branch_comps, segment_radius),
        )
        traces.append(branch_trace)
    if isinstance(color, dict):
        if "soma" not in color:
            raise Exception("Please specify a color for the `soma`.")
        soma_color = color["soma"]
    else:
        soma_color = color
    soma_comps = [c for c in compartments if "soma" in c.labels]
    # Negative bool = -1/0 (True: -1, last soma comp, False: 0, first soma comp)
    soma_comp = soma_comps[-use_last_soma_comp]
    radius = soma_radius if soma_radius is not None else soma_comp.radius
    center = offset + (soma_comp.end if use_last_soma_comp else soma_comp.start)
    traces.append(get_soma_trace(radius, center, soma_color, opacity=soma_opacity))
    for trace in traces:
        fig.add_trace(trace)
    return fig
@_figure
def plot_intersections(
    from_morphology,
    from_pos,
    to_morphology,
    to_pos,
    intersections,
    offset=[0.0, 0.0, 0.0],
    fig=None,
    show=True,
    legend=True,
):
    """
    Compute the translated compartment arrays of two morphologies.

    NOTE(review): the computed arrays are neither plotted nor returned, and
    ``intersections`` is not used - this function looks unfinished.
    """
    shift = np.array(offset)
    from_compartments = (
        np.array(from_morphology.compartment_tree.get_arrays()[0])
        + shift
        + np.array(from_pos)
    )
    to_compartments = (
        np.array(to_morphology.compartment_tree.get_arrays()[0])
        + shift
        + np.array(to_pos)
    )
def _get_branch_width(branch, radii):
if isinstance(radii, dict):
for btype in reversed(branch[-1].labels):
if btype in radii:
return radii[btype]
raise Exception(
"Plotting width not specified for branches of type " + str(branch[-1].labels)
)
return radii
def _get_branch_color(branch, colors):
if isinstance(colors, dict):
for btype in reversed(branch[-1].labels):
if btype in colors:
return colors[btype]
raise Exception(
"Plotting color not specified for branches of type " + str(branch[-1].labels)
)
return colors
def plot_block(fig, origin, sizes, color=None, colorscale="Cividis", **kwargs):
    """
    Add the faces of a block to a figure.

    :param fig: Figure to add the block to.
    :param origin: Corner of the block.
    :param sizes: Size of the block along each dimension.
    :param color: Color of the faces.
    :param colorscale: Colorscale of the faces.
    :param kwargs: Additional keyword arguments for ``fig.add_trace``.
    """
    # `colorscale` is passed by keyword: the 4th positional parameter of
    # `plotly_block` is `colorscale_value`, not `colorscale`.
    edges, faces = plotly_block(origin, sizes, color, colorscale=colorscale)
    # Only the faces are added; the edge trace is currently not drawn.
    fig.add_trace(faces, **kwargs)


def plotly_block(origin, sizes, color=None, colorscale_value=None, colorscale="Cividis"):
    """
    Create the traces of a block.

    :param origin: Corner of the block.
    :param sizes: Size of the block along each dimension.
    :param color: Color of the faces.
    :param colorscale_value: Intensity value of the faces on the colorscale.
    :param colorscale: Colorscale of the faces.
    :returns: The edge trace and the face trace of the block.
    :rtype: tuple
    """
    edges = plotly_block_edges(origin, sizes)
    faces = plotly_block_faces(
        origin,
        sizes,
        color,
        colorscale_value=colorscale_value,
        colorscale=colorscale,
    )
    return edges, faces
def plotly_block_faces(
    origin,
    sizes,
    color=None,
    colorscale_value=None,
    colorscale="Cividis",
    cmin=0,
    cmax=16.0,
):
    """
    Create the mesh trace of the 6 faces of a block.

    :param origin: Corner of the block.
    :param sizes: Size of the block along each dimension.
    :param color: Color of the faces. Takes precedence over the colorscale.
    :param colorscale_value: Intensity of all vertices on the colorscale.
    :param colorscale: Colorscale used with ``colorscale_value``.
    :param cmin: Lower bound of the colorscale.
    :param cmax: Upper bound of the colorscale.
    :returns: The mesh trace; note that Y and Z are swapped in the trace.
    """
    # 8 vertices of a block
    x = origin[0] + np.array([0, 0, 1, 1, 0, 0, 1, 1]) * sizes[0]
    y = origin[1] + np.array([0, 1, 1, 0, 0, 1, 1, 0]) * sizes[1]
    z = origin[2] + np.array([0, 0, 0, 0, 1, 1, 1, 1]) * sizes[2]
    color_args = {}
    # Compare to None: an intensity of 0 is a valid value on the colorscale.
    if colorscale_value is not None:
        color_args = {
            "colorscale": colorscale,
            "intensity": np.ones((8)) * colorscale_value,
            "cmin": cmin,
            "cmax": cmax,
        }
    if color:
        color_args = {"color": color}
    return go.Mesh3d(
        x=x,
        y=z,
        z=y,
        # i, j and k give the vertices of the mesh triangles
        i=[7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
        j=[3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
        k=[0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
        opacity=0.3,
        **color_args,
    )
def plotly_block_edges(origin, sizes):
    """
    Create the line trace of the edges of a block.

    :param origin: Corner of the block.
    :param sizes: Size of the block along each dimension.
    :returns: The line trace; note that Y and Z are swapped in the trace.
    """
    # A single path of 16 unit-cube corners that is drawn as one continuous line.
    unit_x = np.array([0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0])
    unit_y = np.array([0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1])
    unit_z = np.array([0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0])
    x = origin[0] + unit_x * sizes[0]
    y = origin[1] + unit_y * sizes[1]
    z = origin[2] + unit_z * sizes[2]
    line = dict(width=1.0, color="black")
    return go.Scatter3d(x=x, y=z, z=y, mode="lines", line=line, showlegend=False)
def set_scene_range(scene, bounds):
    """
    Set the axis ranges of a scene.

    :param scene: A scene, or any object with a ``layout`` attribute (such as a
      figure), in which case its ``layout.scene`` is used.
    :param bounds: The ranges in X, Y, Z order. The 2nd is applied to the scene's Z
      axis and the 3rd to its Y axis.
    """
    target = scene.layout.scene if hasattr(scene, "layout") else scene
    x_bounds, y_bounds, z_bounds = bounds[0], bounds[1], bounds[2]
    target.xaxis.range = x_bounds
    target.yaxis.range = z_bounds
    target.zaxis.range = y_bounds
def set_scene_aspect(scene, bounds, mode="equal", swapaxes=True):
    """
    Set the aspect of a scene.

    :param scene: The scene to modify.
    :param bounds: The ``[min, max]`` bounds per dimension; only used in ``equal`` mode.
    :param mode: With ``equal`` the aspect ratio of each axis is its span divided by
      the largest span. Any other value is set as the scene's ``aspectmode``.
    :param swapaxes: If True, the 2nd ratio is assigned to the Z axis and the 3rd to
      the Y axis.
    :type swapaxes: bool
    """
    if mode != "equal":
        scene.aspectmode = mode
        return
    spans = np.array([bound[1] - bound[0] for bound in bounds])
    spans = spans / np.max(spans)
    axes = ["x", "z", "y"] if swapaxes else ["x", "y", "z"]
    scene.aspectratio = dict(zip(axes, spans))
def set_morphology_scene_range(scene, offset_morphologies):
    """
    Set the range on a scene containing multiple morphologies.

    All axes get a range of the same length: the largest span of the combined bounds.

    :param scene: A scene of the figure. If the figure itself is given, ``figure.layout.scene`` will be used.
    :param offset_morphologies: A list of tuples where the first element is offset and the 2nd is the :class:`Morphology`
    """
    bounds = np.array([get_morphology_range(m[1], m[0]) for m in offset_morphologies])
    # Lowest minimum and highest maximum per dimension, over all morphologies.
    lows = np.min(bounds, axis=0)[:, 0]
    highs = np.max(bounds, axis=0)[:, 1]
    combined_bounds = np.column_stack((lows, highs))
    span = max(highs - lows)
    # Every axis starts at its own minimum but covers the same span.
    combined_bounds[:, 1] = combined_bounds[:, 0] + span
    set_scene_range(scene, combined_bounds)
def get_morphology_range(morphology, offset=None, soma_radius=None):
    """
    Calculate the ``[min, max]`` bounds of a morphology per dimension.

    :param morphology: The morphology; its ``flatten`` method is called to obtain the
      x, y and z vectors.
    :param offset: Offset added to the bounds, defaults to the origin.
    :param soma_radius: If given, each dimension spans at least from minus the radius
      to plus the radius, before the offset is added.
    :returns: A list of 3 ``[min, max]`` lists.
    :rtype: list
    """
    if offset is None:
        offset = [0.0, 0.0, 0.0]
    radius = soma_radius or 0.0
    bounds = []
    for axis, values in enumerate(morphology.flatten(vectors=["x", "y", "z"])):
        low = min(min(values), -radius) + offset[axis]
        high = max(max(values), radius) + offset[axis]
        bounds.append([low, high])
    return bounds
def hdf5_plot_spike_raster(
    spike_recorders,
    input_region=None,
    show=True,
    cutoff=0,
    cell_type_sort=None,
    cell_sort=None,
):
    """
    Create a spike raster plot from an HDF5 group of spike recorders.

    :param spike_recorders: Mapping of name to dataset. One-dimensional (or single
      column) datasets hold the spike times of a single cell, other datasets hold
      the cell ids in their first column and the spike times in their second.
    :param input_region: Specifies an interval ``[min, max]`` on the x axis to highlight
      as active input to the simulation.
    :type input_region: 2-element list-like
    :param show: Immediately plot the result
    :type show: bool
    :param cutoff: Amount of ms initial simulation to ignore.
    :type cutoff: float
    :param cell_type_sort: A function to sort the cell types. Must take 2 dictionaries
      as arguments, being the raster plot's x values per label and y values per label.
      Must return an array of labels matching those of the x and y values to order them.
    :type cell_type_sort: function-like
    :param cell_sort: A function that takes the cell type label and set of ids and returns
      a map to sort them.
    :type cell_sort: function-like
    :returns: The raster plot figure.
    """
    x_labelled = {}
    y_labelled = {}
    colors = {}
    for cell_id, dataset in spike_recorders.items():
        attrs = dict(dataset.attrs)
        if len(dataset.shape) == 1 or dataset.shape[1] == 1:
            times = dataset[()] - cutoff
            set_ids = np.ones(len(times)) * int(
                attrs.get("cell_id", attrs.get("cell", cell_id))
            )
        else:
            times = dataset[:, 1] - cutoff
            set_ids = dataset[:, 0]
        label = attrs.get("label", "unlabelled")
        # The first dataset of a label determines the color of that label.
        colors.setdefault(label, attrs.get("color", "black"))
        # Add the spike timings on the X axis.
        x_labelled.setdefault(label, []).extend(times)
        # Set the cell id for the Y axis of each added spike timing.
        y_labelled.setdefault(label, []).extend(set_ids)
    # Use the parallel arrays x & y to plot a spike raster
    fig = go.Figure(
        layout=dict(
            xaxis=dict(title_text="Time [ms]"), yaxis=dict(title_text="Cell [ID]")
        )
    )
    if cell_type_sort is None:

        def cell_type_sort(x, y):
            # Sorts the cell type dictionary by cell type size
            return [k for k, v in sorted(y.items(), key=lambda kv: len(kv[1]))]

    def _cell_sort(label, sorted_ids):
        # Maps each unique y value to a sorted index starting from 0. Defined
        # separately so that it can be used as fallback mechanism later.
        return dict(zip(sorted_ids, np.argsort(sorted_ids)))

    if cell_sort is None:
        # If no cell sorter is given we use the fallback sorter as default sorter.
        cell_sort = _cell_sort
    sorted_labels = cell_type_sort(x_labelled, y_labelled)
    start_id = 0
    # Largest spike time over all labels, used for the range of the X axis.
    x_max = 0
    for label in sorted_labels:
        x = np.array(x_labelled[label])
        y = np.array(y_labelled[label])
        if len(y) > 0:
            uy = np.unique(y)
            # Ask the cell sorter to give a map for the unique y values. If it returns
            # something Falsy (such as None) we use the default cell sorter.
            id_map = cell_sort(label, uy) or _cell_sort(label, uy)
            len_diff = len(uy) - len(id_map)
            if len_diff > 0:
                warn(
                    f"Sorted '{label}' array do not contain all cell ids, {len_diff} {label} omitted from raster."
                )
                # `np.isin` needs a sequence; a dict keys view would be treated as a
                # single object and never match.
                y_mask = np.isin(y, list(id_map.keys()))
                x = x[y_mask]
                y = y[y_mask]
            # Build a new numpy array using the `id_map` dictionary lookup. A list
            # comprehension also works when the mask removed every spike.
            y = np.array([id_map[cell] for cell in y.tolist()]) + start_id
            start_id += len(uy)
        plot_spike_raster(
            x,
            y,
            label=label,
            fig=fig,
            show=False,
            color=colors[label],
            input_region=input_region,
        )
        x_max = max(x_max, np.max(x, initial=0))
    fig.update_layout(xaxis=dict(range=[0, x_max]))
    if show:
        fig.show()
    return fig
def hdf5_gdf_plot_spike_raster(
    spike_recorders, input_region=None, fig=None, show=True, **kwargs
):
    """
    Create a spike raster plot from an HDF5 group of spike recorders saved from NEST gdf files.
    Each HDF5 dataset includes the spike timings of the recorded cell populations, with spike
    times in the first column and neuron IDs in the second column. Each dataset needs a
    ``label`` and ``color`` attribute; every label gets its own subplot.

    :param input_region: Specifies an interval ``[min, max]`` on the x axis to highlight
      as active input to the simulation.
    :type input_region: 2-element list-like
    :param fig: Figure to plot on; its layout is replaced by a subplot layout. A new
      figure is created if omitted.
    :param show: Immediately plot the result
    :type show: bool
    :param kwargs: Additional keyword arguments for :func:`plot_spike_raster`.
    :returns: The raster plot figure.
    """
    if fig is None:
        fig = go.Figure()
    x = {}
    y = {}
    colors = {}
    for dataset in spike_recorders.values():
        attrs = dict(dataset.attrs)
        label = attrs["label"]
        # The last dataset of a label determines the color of that label.
        colors[label] = attrs["color"]
        # Add the spike timings on the X axis.
        x.setdefault(label, []).extend(dataset[:, 0])
        # Set the neuron id for the Y axis of each added spike timing.
        y.setdefault(label, []).extend(dataset[:, 1])
    subplots_fig = make_subplots(cols=1, rows=len(x), subplot_titles=list(x.keys()))
    # Align the X axis of all subplots to the earliest and latest spike.
    _min = float("inf")
    _max = -float("inf")
    for times in x.values():
        if len(times) > 0:
            _min = min(_min, np.min(np.array(times)))
            _max = max(_max, np.max(np.array(times)))
    if _min <= _max:
        subplots_fig.update_xaxes(range=[_min, _max])
    # Overwrite the layout and grid of the single plot that is handed to us
    # to turn it into a subplots figure.
    fig._grid_ref = subplots_fig._grid_ref
    fig._layout = subplots_fig._layout
    for i, label in enumerate(x):
        plot_spike_raster(
            x[label],
            y[label],
            label=label,
            fig=fig,
            row=i + 1,
            col=1,
            show=False,
            color=colors[label],
            input_region=input_region,
            **kwargs,
        )
    if show:
        fig.show()
    return fig
@_figure
@_input_highlight
def plot_spike_raster(
    spike_timings,
    cell_ids,
    fig=None,
    row=None,
    col=None,
    show=True,
    legend=True,
    label="Cells",
    color=None,
):
    """
    Add a spike raster to a figure: one square marker per spike.

    :param spike_timings: The X values, the time of each spike.
    :param cell_ids: The Y values, the id of the cell that fired each spike.
    :param fig: Figure to plot on.
    :param row: Subplot row to add the trace to.
    :param col: Subplot column to add the trace to.
    :param label: Name of the trace.
    :param color: Marker color, defaults to black.
    :returns: The figure the raster was plotted on, like the other plot functions.
    """
    marker = dict(symbol="square", size=2, color=color or "black")
    fig.add_trace(
        go.Scatter(
            x=spike_timings,
            y=cell_ids,
            mode="markers",
            marker=marker,
            name=label,
        ),
        row=row,
        col=col,
    )
    return fig
def hdf5_gather_voltage_traces(handle, root, groups=None):
    """
    Collect the datasets under one or more groups into a :class:`CellTraceCollection`.

    The cell id of each dataset is the integer in front of the first ``.`` of its
    name. The metadata of each trace holds the id, the dataset name (``location``),
    the group path and all attributes of the dataset.

    :param handle: Handle that is indexed with ``root + group``.
    :param root: Path prefix of the groups.
    :param groups: Names under ``root``; when empty or omitted ``root`` itself is used.
    """
    if not groups:
        groups = [""]
    traces = CellTraceCollection()
    for group in groups:
        path = root + group
        # If an element of `groups` points to a single set, rather than a group
        # catch the exception and construct a single element group from the single set
        try:
            entries = handle[path].items()
        except AttributeError:
            entries = ((group, handle[path]),)
            path = root
        for name, dataset in entries:
            cell_id = int(name.split(".")[0])
            meta = {"id": cell_id, "location": name, "group": path}
            # Dataset attributes are applied last and can override the keys above.
            for key, value in dataset.attrs.items():
                meta[key] = value
            traces.add(cell_id, meta, dataset)
    return traces
@_figure
@_input_highlight
def plot_traces(
    traces, fig=None, show=True, legend=True, cutoff=0, range=None, x=None, **kwargs
):
    """
    Plot a trace collection, with one subplot per cell.

    :param traces: The collection to plot; it is sorted by cell order first.
    :type traces: CellTraceCollection
    :param fig: Figure to plot on; it is turned into a subplots figure.
    :param cutoff: Amount of initial samples to drop from the traces and from ``x``.
    :type cutoff: int
    :param range: Interval ``[min, max]`` of x values to plot; requires ``x``.
    :type range: 2-element list-like
    :param x: X values of the samples.
    :param kwargs: Additional keyword arguments for ``make_subplots``.
    :returns: The figure the traces were plotted on.
    """
    traces.order()
    subplots_fig = make_subplots(
        cols=1,
        rows=len(traces),
        subplot_titles=[trace.title for trace in traces],
        x_title="Time [ms]",
        y_title="Membrane potential [mV]",
        **kwargs,
    )
    # Save the data already in the given figure
    _data = fig.data
    for k in dir(subplots_fig):
        v = getattr(subplots_fig, k)
        if isinstance(v, types.MethodType):
            # Unbind subplots_fig methods and bind to fig.
            v = v.__func__.__get__(fig)
        fig.__dict__[k] = v
    # Restore the data
    fig.data = _data
    fig.update_layout(height=max(len(traces) * 130, 300))
    legend_groups = set()
    legends = traces.legends
    colors = traces.colors
    mask = None
    if x is not None:
        # The cutoff is applied to the x values whenever they are given, so that
        # they stay aligned with the data, which is always cut off.
        x = np.array(x)[cutoff:]
        if range is not None:
            mask = (x >= range[0]) & (x <= range[1])
            x = x[mask]
    for i, cell_traces in enumerate(traces):
        for j, trace in enumerate(cell_traces):
            # Traces beyond the configured legends or colors are left unnamed or
            # uncolored instead of raising an IndexError.
            name = legends[j] if j < len(legends) else None
            color = trace.color or (colors[j] if j < len(colors) else None)
            data = trace.data[cutoff:]
            if mask is not None:
                data = data[mask]
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=data,
                    legendgroup=name,
                    name=name,
                    # Show each legend entry only once, for its first trace.
                    showlegend=name not in legend_groups,
                    mode="lines",
                    marker=dict(color=color),
                ),
                col=1,
                row=i + 1,
            )
            legend_groups.add(name)
    return fig
class PSTH:
    """
    The rows of a peri-stimulus time histogram plot.
    """

    def __init__(self):
        self.rows = []

    def add_row(self, row):
        """
        Append a row and store its insertion index on the row as ``index``.
        """
        row.index = len(self.rows)
        self.rows.append(row)

    def ordered_rows(self):
        """
        Return a new list of the rows sorted by their ``order``; a falsy order counts
        as 0.
        """

        def sort_key(row):
            return row.order or 0

        return sorted(self.rows, key=sort_key)
class PSTHStack:
    """
    One stack of a PSTH row: the collected values of its datasets and the runs they
    came from.
    """

    def __init__(self, name, color):
        self.name = name
        # The color is stored as its string representation.
        self.color = str(color)
        self.list = []
        self._runs = set()

    def extend(self, arr, run=0):
        """
        Add the second column of ``arr`` to the stack and register the run.
        """
        self._runs.add(run)
        self.list.extend(arr[:, 1])

    @property
    def runs(self):
        """
        The amount of distinct runs that were added to this stack.
        """
        return len(self._runs)
class PSTHRow:
    """
    One row (subplot) of a PSTH, holding one or more stacks.

    :param name: Name of the row.
    :param color: Base color of the row; a color is picked if it is falsy.
    :param order: Sort order of the row.
    """

    def __init__(self, name, color, order=0):
        from colour import Color

        self.name = name
        self.order = order
        base = Color(color) if color else Color(pick_for=random.random())
        # 6 shades from the base color to black, one for each new stack.
        self.palette = list(base.range_to("black", 6))
        self.stacks = {}
        self.max = -float("inf")

    def extend(self, arr, stack=None, run=0):
        """
        Add data to a stack of the row, creating the stack on first use, and keep
        track of the largest value in the second column of all added data.
        """
        if stack not in self.stacks:
            shade = self.palette[len(self.stacks)]
            self.stacks[stack] = PSTHStack(stack or self.name, shade)
        self.stacks[stack].extend(arr, run=run)
        if len(arr) > 0:
            self.max = max(self.max, np.max(arr[:, 1]))
@_figure
def hdf5_plot_psth(
    network, handle, duration=3, cutoff=0, start=0, fig=None, gaps=True, **kwargs
):
    """
    Plot a peri-stimulus time histogram from an HDF5 group of spike datasets.

    Datasets are grouped into rows by the plotting label of their cell type (taken
    from the ``cell_types`` attribute, or from the ``label`` attribute if it is the
    name of a cell type) and into stacks by their ``stack`` attribute.

    :param network: The network the results belong to.
    :param handle: Mapping of name to dataset; the second column of each dataset is
      used as spike time.
    :param duration: Width of the histogram bins.
    :param cutoff: Amount subtracted from each spike time.
    :param start: Start of the X axis and of the first bin.
    :param fig: Figure to plot on; it is turned into a subplots figure.
    :param gaps: If False, the gaps between the bars are removed.
    :type gaps: bool
    :param kwargs: ``x_title``, ``y_title`` and ``title`` are used for the axis titles
      and figure title.
    :returns: The figure the PSTH was plotted on.
    :raises Exception: If a stack can't be linked to a cell type of the network.
    """
    psth = PSTH()
    row_map = {}
    for g in handle.values():
        label = g.attrs.get("label", "unlabelled")
        cts = g.attrs.get("cell_types", [])
        color = None
        # Check the length: the attribute may be an array, and the truth value of an
        # array with more than 1 element is ambiguous.
        if len(cts) > 0:
            if len(cts) > 1:
                warn(
                    "Multiple cell types detected in a single dataset, can't perform proper PSTH"
                )
            ct = network.configuration.cell_types[cts[0]]
            label = ct.plotting.label
            color = ct.plotting.color
        elif label in network.configuration.cell_types:
            ct = network.configuration.cell_types[label]
            label = ct.plotting.label
            color = ct.plotting.color
        if label not in row_map:
            color = g.attrs.get("color", color)
            order = g.attrs.get("order", 0)
            row_map[label] = row = PSTHRow(label, color, order=order)
            psth.add_row(row)
        else:
            row = row_map[label]
        run_id = g.attrs.get("run_id", 0)
        adjusted = g[()]
        adjusted[:, 1] = adjusted[:, 1] - cutoff
        row.extend(adjusted, stack=g.attrs.get("stack", None), run=run_id)
    subplots_fig = make_subplots(
        cols=1,
        rows=len(psth.rows),
        subplot_titles=[row.name for row in psth.ordered_rows()],
        x_title=kwargs.get("x_title", "Time [ms]"),
        y_title=kwargs.get("y_title", "Population firing rate [Hz]"),
    )
    for k in dir(subplots_fig):
        if k == "data" or k == "_data":
            # Don't overwrite data already on the fig
            continue
        v = getattr(subplots_fig, k)
        if isinstance(v, types.MethodType):
            # Unbind subplots_fig methods and bind to fig.
            v = v.__func__.__get__(fig)
        fig.__dict__[k] = v
    # Align xaxis ranges to max of all rows
    _max = max((row.max for row in psth.rows), default=-float("inf"))
    fig.update_xaxes(range=[start, _max])
    fig.update_layout(title_text=kwargs.get("title", "PSTH"))
    if not gaps:
        fig.update_layout(bargap=0, bargroupgap=0)
    cell_types = network.get_cell_types()
    for i, row in enumerate(psth.ordered_rows()):
        for name, stack in sorted(row.stacks.items(), key=lambda x: x[0]):
            counts, bins = np.histogram(stack.list, bins=np.arange(start, _max, duration))
            # Workaround of Workarounds for merging info in scaffold and in results
            # Compares plotting colors to identify cell type ...
            for cell_type in cell_types:
                if cell_type.plotting.color.lower() == stack.color:
                    current_cell_type = cell_type
                    break
            else:
                raise Exception(
                    f"Couldn't link result group '{name or row.name}' to a network cell type."
                )
            cell_num_single_run = network.get_placed_count(current_cell_type.name)
            cell_num = cell_num_single_run * (stack.runs)
            if str(name).startswith("##"):
                # Lazy way to order the stacks; Stack names can start with ## and a number
                # and it will be sorted by name, but the ## and number are not displayed.
                name = name[4:]
            bar_kwargs = dict()
            if not gaps:
                bar_kwargs["marker_line_width"] = 0
            trace = go.Bar(
                x=bins,
                # Spike counts per bin, per cell and per run, scaled by 1000 over the
                # bin width.
                y=counts / cell_num * 1000 / duration,
                name=name or row.name,
                marker=dict(color=stack.color),
                **bar_kwargs,
            )
            fig.add_trace(trace, row=i + 1, col=1)
    return fig
class MorphologyScene:
    """
    Collects morphologies with their offset and plotting arguments, and plots all of
    them on a single figure.

    :param fig: Figure to plot on; a new figure is created if omitted (or falsy).
    """

    def __init__(self, fig=None):
        self.fig = fig or go.Figure()
        self._morphologies = []

    def add_morphology(self, morphology, offset=None, **kwargs):
        """
        Register a morphology for plotting.

        :param morphology: The morphology to plot.
        :param offset: Offset of the morphology, defaults to the origin.
        :param kwargs: Additional keyword arguments for :func:`plot_morphology`.
        """
        if offset is None:
            # A new list per call, so that stored offsets are never shared.
            offset = [0.0, 0.0, 0.0]
        self._morphologies.append((offset, morphology, kwargs))

    def show(self):
        """
        Prepare the plot and show the figure.
        """
        self.prepare_plot()
        self.fig.show()

    def prepare_plot(self):
        """
        Plot every registered morphology and set the scene range to fit all of them.

        :raises ValueError: If no morphologies were added. (The previously referenced
          ``MorphologyError`` is not defined in this module and surfaced as a
          ``NameError``.)
        """
        if len(self._morphologies) == 0:
            raise ValueError("Cannot show empty MorphologyScene")
        for offset, morphology, plot_kwargs in self._morphologies:
            plot_morphology(
                morphology,
                offset=offset,
                show=False,
                set_range=False,
                fig=self.fig,
                **plot_kwargs,
            )
        set_morphology_scene_range(self.fig.layout.scene, self._morphologies)