################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################
import numpy as np
from .deterministic import Deterministic
from .node import Moments
from bayespy.utils import misc
class Take(Deterministic):
    """
    Choose elements/sub-arrays along a plate axis

    Basically, applies `np.take` on a plate axis. Allows advanced mapping of
    plates: the shape of ``indices`` replaces the chosen plate axis, so a
    multi-dimensional ``indices`` array turns one plate axis into several.

    Parameters
    ----------

    node : Node
        A node to apply the take operation on.
    indices : array of integers
        Plate elements to pick along a plate axis. Each index must lie in
        ``[-N, N)``, where ``N`` is the length of that plate axis in ``node``.
    plate_axis : int (negative)
        The plate axis to pick elements from (default: -1).

    Raises
    ------

    ValueError
        If ``plate_axis`` is not a negative integer within the plates of
        ``node``, or if ``indices`` are not integers or are out of bounds.

    See also
    --------

    numpy.take

    Examples
    --------

    >>> from bayespy.nodes import Gamma, Take
    >>> alpha = Gamma([1, 2, 3], [1, 1, 1])
    >>> x = Take(alpha, [1, 1, 2, 2, 1, 0])
    >>> x.get_moments()[0]
    array([2., 2., 3., 3., 2., 1.])
    """

    def __init__(self, node, indices, plate_axis=-1, **kwargs):
        self._moments = node._moments
        self._parent_moments = (node._moments,)
        self._indices = np.array(indices)
        self._plate_axis = plate_axis

        # Validate plate_axis BEFORE using it to index node.plates. Reading
        # the axis length first would make invalid axes fail with a bare
        # IndexError/TypeError and leave the checks below unreachable.
        if not misc.is_scalar_integer(plate_axis):
            raise ValueError("Plate axis must be integer")
        if plate_axis >= 0:
            raise ValueError("plate_axis must be negative index")
        if plate_axis < -len(node.plates):
            raise ValueError("plate_axis out of bounds")

        # Safe now: plate_axis is a negative integer inside node.plates.
        self._original_length = node.plates[plate_axis]

        # Validate the indices against the length of the chosen axis.
        if not issubclass(self._indices.dtype.type, np.integer):
            raise ValueError("Indices must be integers")
        if (np.any(self._indices < -self._original_length) or
                np.any(self._indices >= self._original_length)):
            raise ValueError("Index out of bounds")

        super().__init__(node, dims=node.dims, **kwargs)

    def _compute_moments(self, u_parent):
        """
        Apply ``np.take`` along the plate axis to every parent moment array.
        """
        u = []
        for (ui, dimi) in zip(u_parent, self.dims):
            # The variable dimensions of each moment follow the plates, so the
            # plate axis is shifted left by the number of those dimensions.
            axis = self._plate_axis - len(dimi)
            # Just in case the taken axis is using broadcasting and has unit
            # length in u_parent, force it to have the correct length along
            # the axis in order to avoid errors in np.take.
            broadcaster = np.ones((self._original_length,) + (-axis-1)*(1,))
            u.append(np.take(ui*broadcaster, self._indices, axis=axis))
        return u

    def _compute_message_to_parent(self, index, m_child, u_parent):
        """
        Scatter the child messages back onto the parent's plate axis.

        Each message array is passed to ``misc.put_simple`` with the taken
        indices and the original axis length.
        NOTE(review): how repeated indices are combined is decided by
        ``misc.put_simple`` (presumably summed); confirm there.
        """
        m = [
            misc.put_simple(
                mi,
                self._indices,
                axis=self._plate_axis-len(dimi),
                length=self._original_length,
            )
            for (mi, dimi) in zip(m_child, self.dims)
        ]
        return m

    def _compute_weights_to_parent(self, index, weights):
        """
        Scatter plate weights back onto the parent's plate axis.

        Unlike the messages, the weights are indexed with the plate axis
        directly; no variable dimensions are subtracted.
        """
        return misc.put_simple(
            weights,
            self._indices,
            axis=self._plate_axis,
            length=self._original_length,
        )

    def _compute_plates_to_parent(self, index, plates):
        """
        Map plates of this node back to the plates of the parent.

        The ``np.ndim(indices)`` axes created by the take operation are
        collapsed back into one axis of the original length.
        """
        # Number of axes created by take operation
        N = np.ndim(self._indices)
        if self._plate_axis >= 0:
            raise RuntimeError("Plate axis should be negative")
        # In this node's plates, the taken axes occupy the negative-index
        # range [end_before, start_after).
        end_before = self._plate_axis - N + 1
        start_after = self._plate_axis + 1
        # A slice bound of 0 would mean "nothing" (plates[:0]) or
        # "everything" (plates[0:]) rather than "up to/from the end", so
        # those cases are handled explicitly.
        if end_before == 0:
            return plates + (self._original_length,)
        elif start_after == 0:
            return plates[:end_before] + (self._original_length,)
        return (plates[:end_before]
                + (self._original_length,)
                + plates[start_after:])

    def _compute_plates_from_parent(self, index, parent_plates):
        """
        Replace the parent's plate axis with the shape of ``indices``.
        """
        axis = self._plate_axis
        # parent_plates[axis+1:] would be the whole tuple when axis == -1,
        # so the trailing part must be empty in that case.
        tail = parent_plates[(axis+1):] if axis != -1 else ()
        return parent_plates[:axis] + np.shape(self._indices) + tail

    def _compute_plates_multiplier_from_parent(self, index, parent_multiplier):
        """
        Pass the parent's plate multiplier through unchanged.

        Raises NotImplementedError if any multiplier differs from 1, because
        non-trivial multipliers are not supported by this node.
        """
        if any(p != 1 for p in parent_multiplier):
            raise NotImplementedError(
                "Take node doesn't yet support plate multipliers {0}"
                .format(parent_multiplier)
            )
        return parent_multiplier