Source code for bayespy.inference.vmp.nodes.take

################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################

import numpy as np

from .deterministic import Deterministic
from .node import Moments
from bayespy.utils import misc


[docs]class Take(Deterministic): """ Choose elements/sub-arrays along a plate axis Basically, applies `np.take` on a plate axis. Allows advanced mapping of plates. Parameters ---------- node : Node A node to apply the take operation on. indices : array of integers Plate elements to pick along a plate axis. plate_axis : int (negative) The plate axis to pick elements from (default: -1). See also -------- numpy.take Examples -------- >>> from bayespy.nodes import Gamma, Take >>> alpha = Gamma([1, 2, 3], [1, 1, 1]) >>> x = Take(alpha, [1, 1, 2, 2, 1, 0]) >>> x.get_moments()[0] array([2., 2., 3., 3., 2., 1.]) """
[docs] def __init__(self, node, indices, plate_axis=-1, **kwargs): self._moments = node._moments self._parent_moments = (node._moments,) self._indices = np.array(indices) self._plate_axis = plate_axis self._original_length = node.plates[plate_axis] # Validate arguments if not misc.is_scalar_integer(plate_axis): raise ValueError("Plate axis must be integer") if plate_axis >= 0: raise ValueError("plate_axis must be negative index") if plate_axis < -len(node.plates): raise ValueError("plate_axis out of bounds") if not issubclass(self._indices.dtype.type, np.integer): raise ValueError("Indices must be integers") if (np.any(self._indices < -self._original_length) or np.any(self._indices >= self._original_length)): raise ValueError("Index out of bounds") super().__init__(node, dims=node.dims, **kwargs)
def _compute_moments(self, u_parent): u = [] for (ui, dimi) in zip(u_parent, self.dims): axis = self._plate_axis - len(dimi) # Just in case the taken axis is using broadcasting and has unit # length in u_parent, force it to have the correct length along the # axis in order to avoid errors in np.take. broadcaster = np.ones((self._original_length,) + (-axis-1)*(1,)) u.append(np.take(ui*broadcaster, self._indices, axis=axis)) return u def _compute_message_to_parent(self, index, m_child, u_parent): m = [ misc.put_simple( mi, self._indices, axis=self._plate_axis-len(dimi), length=self._original_length, ) for (mi, dimi) in zip(m_child, self.dims) ] return m def _compute_weights_to_parent(self, index, weights): return misc.put_simple( weights, self._indices, axis=self._plate_axis, length=self._original_length, ) def _compute_plates_to_parent(self, index, plates): # Number of axes created by take operation N = np.ndim(self._indices) if self._plate_axis >= 0: raise RuntimeError("Plate axis should be negative") end_before = self._plate_axis - N + 1 start_after = self._plate_axis + 1 if end_before == 0: return plates + (self._original_length,) elif start_after == 0: return plates[:end_before] + (self._original_length,) return (plates[:end_before] + (self._original_length,) + plates[start_after:]) def _compute_plates_from_parent(self, index, parent_plates): plates = parent_plates[:self._plate_axis] + np.shape(self._indices) if self._plate_axis != -1: plates = plates + parent_plates[(self._plate_axis+1):] return plates def _compute_plates_multiplier_from_parent(self, index, parent_multiplier): if any(p != 1 for p in parent_multiplier): raise NotImplementedError( "Take node doesn't yet support plate multipliers {0}" .format(parent_multiplier) ) return parent_multiplier