Source code for bayespy.inference.vmp.nodes.concatenate
################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################
import numpy as np
from bayespy.utils import misc
from .deterministic import Deterministic
from .node import Moments
class Concatenate(Deterministic):
    """
    Concatenate similar nodes along a plate axis.

    Nodes must be of same type and dimensionality. Also, plates must be
    identical except for the plate axis along which the concatenation is
    performed.

    Parameters
    ----------
    *nodes
        Parents to concatenate.  At least one of them must expose a
        ``_moments`` attribute; the others are converted to that moments
        class.
    axis : int, optional
        Plate axis along which to concatenate.  Must be negative (counted
        from the last plate axis).  Default is -1.
    **kwargs
        Passed on to the ``Deterministic`` constructor.

    See also
    --------
    numpy.concatenate
    """

    def __init__(self, *nodes, axis=-1, **kwargs):
        """
        Create the concatenation node.

        Raises
        ------
        ValueError
            If `axis` is non-negative, if no parent provides moments, if
            the parents cannot be converted to common moments, or if the
            parents have different ``dims``.
        """
        if axis >= 0:
            raise ValueError("Currently, only negative axis indeces "
                             "are allowed.")

        self._axis = axis

        # Use the moments of the first parent that has them.  Parents that
        # lack the attribute (presumably plain arrays/constants -- confirm)
        # are skipped here and converted below.
        parent_moments = None
        for node in nodes:
            try:
                parent_moments = node._moments
            except AttributeError:
                pass
            else:
                break

        if parent_moments is None:
            raise ValueError("Couldn't determine parent moments")

        # All parents must have same moments
        self._parent_moments = (parent_moments,) * len(nodes)
        self._moments = parent_moments

        # Convert nodes
        try:
            nodes = [
                self._ensure_moments(
                    node,
                    parent_moments.__class__,
                    **parent_moments.get_instance_conversion_kwargs()
                )
                for node in nodes
            ]
        except Moments.NoConverterError as err:
            raise ValueError("Parents have different moments") from err

        # Dimensionality of the node
        dims = tuple(nodes[0].dims)
        for node in nodes:
            if node.dims != dims:
                raise ValueError("Parents have different dimensionalities")

        super().__init__(
            *nodes,
            dims=dims,
            allow_dependent_parents=True,  # because parent plates are kept separate
            **kwargs
        )

        # Length of each parent on the concatenated plate axis
        self._lengths = [parent.plates[axis] for parent in self.parents]

        # Compute start indices for each parent on the concatenated plate
        # axis: parent k occupies [_indices[k], _indices[k+1]).
        self._indices = np.zeros(len(nodes) + 1, dtype=np.int64)
        self._indices[1:] = np.cumsum([int(length)
                                       for length in self._lengths])

    def _get_id_list(self):
        """
        Parents don't need to be independent for this node so remove duplicates
        """
        return list(set(super()._get_id_list()))

    def _compute_plates_to_parent(self, index, plates):
        """
        Return `plates` with the concatenation axis set to the length of
        parent `index` on that axis.
        """
        plates = list(plates)
        plates[self._axis] = self.parents[index].plates[self._axis]
        return tuple(plates)

    def _compute_plates_from_parent(self, index, plates):
        """
        Return `plates` with the concatenation axis set to the sum of all
        parents' lengths on that axis.
        """
        plates = list(plates)
        plates[self._axis] = sum(parent.plates[self._axis]
                                 for parent in self.parents)
        return tuple(plates)

    def _plates_multiplier_from_parent(self, index):
        """
        Return an empty plate multiplier.

        Raises
        ------
        ValueError
            If any parent has a plate multiplier different from one.
        """
        for parent in self.parents:
            if np.any(np.array(parent.plates_multiplier) != 1):
                raise ValueError("Concatenation node does not support plate "
                                 "multipliers.")
        return ()

    def _select_parent_slice(self, array, index, axis):
        """
        Return the part of `array` that belongs to parent `index`.

        `axis` is the (negative) array axis holding the concatenated
        plates.  If `array` does not have that axis, or has length one (or
        zero) on it, it is returned unchanged.
        """
        if np.ndim(array) >= abs(axis) and np.shape(array)[axis] > 1:
            # Splitting at [start, stop] yields three arrays; the middle
            # one is the segment of this parent.
            bounds = self._indices[index:(index+2)]
            return np.split(array, bounds, axis=axis)[1]
        return array

    def _compute_weights_to_parent(self, index, weights):
        """
        Return the weights restricted to the plates of parent `index`.
        """
        # Both the start and the stop index are needed; using only the start
        # index would return everything from this parent to the end.
        return self._select_parent_slice(weights, index, self._axis)

    def _compute_message_to_parent(self, index, m, *u_parents):
        """
        Return the message `m` restricted to the plates of parent `index`.
        """
        msg = []
        for (i, mi) in enumerate(m):
            # Fix plate axis to array axis
            axis = self._axis - len(self.dims[i])
            # Find the slice from the message
            msg.append(self._select_parent_slice(mi, index, axis))
        return msg

    def _compute_moments(self, *u_parents):
        """
        Concatenate the parents' moments along the plate axis.

        Returns a list with one concatenated array per entry of
        ``self.dims``.
        """
        # TODO/FIXME: Unfortunately, np.concatenate doesn't support
        # broadcasting but moment messages may use broadcasting.
        #
        # WORKAROUND: Broadcast the arrays explicitly to have same shape
        # except for the concatenated axis.
        u = []
        for (i, dims_i) in enumerate(self.dims):
            # Fix plate axis to array axis
            axis = self._axis - len(dims_i)

            # Find broadcasted shape, ignoring the concatenated axis
            ui_parents = [u_parent[i] for u_parent in u_parents]
            shapes = [list(np.shape(uip)) for uip in ui_parents]
            for shape in shapes:
                if len(shape) >= abs(axis):
                    shape[axis] = 1
            bc_shape = misc.broadcasted_shape(*shapes)

            # Concatenated axis must be broadcasted explicitly
            bc_shapes = [
                misc.broadcasted_shape(bc_shape,
                                       (length,) + (1,)*(abs(axis)-1))
                for length in self._lengths
            ]

            # Broadcast explicitly
            ui_parents = [uip * np.ones(shape)
                          for (uip, shape) in zip(ui_parents, bc_shapes)]

            # Concatenate
            u.append(np.concatenate(ui_parents, axis=axis))

        return u