Source code for bayespy.inference.vmp.nodes.concat_gaussian
import numpy as np
from bayespy.utils import misc
from bayespy.utils import linalg
from .gaussian import GaussianMoments
from .deterministic import Deterministic

class ConcatGaussian(Deterministic):
    """Concatenate Gaussian vectors along the variable axis (not plate axis)

    NOTE: This concatenates on the variable axis! That is, the dimensionality
    of the resulting Gaussian vector is the sum of the dimensionalities of the
    input Gaussian vectors.

    TODO: Add support for Gaussian arrays and arbitrary concatenation axis.
    """
    def __init__(self, *nodes, **kwargs):

        # Number of nodes to concatenate
        N = len(nodes)

        # This is stuff that will be useful when implementing arbitrary
        # concatenation. That is, first determine ndim.
        #
        # # Convert nodes to Gaussians (if they are not nodes, don't worry)
        # nodes_gaussian = []
        # for node in nodes:
        #     try:
        #         node_gaussian = node._convert(GaussianMoments)
        #     except AttributeError:  # Moments.NoConverterError:
        #         nodes_gaussian.append(node)
        #     else:
        #         nodes_gaussian.append(node_gaussian)
        # nodes = nodes_gaussian
        #
        # # Determine shape from the first Gaussian node
        # shape = None
        # for node in nodes:
        #     try:
        #         shape = node.dims[0]
        #     except AttributeError:
        #         pass
        #     else:
        #         break
        # if shape is None:
        #     raise ValueError("Couldn't determine shape from the input nodes")
        #
        # ndim = len(shape)
        nodes = [self._ensure_moments(node, GaussianMoments, ndim=1)
                 for node in nodes]

        D = sum(node.dims[0][0] for node in nodes)
        shape = (D,)

        self._moments = GaussianMoments(shape)
        self._parent_moments = [node._moments for node in nodes]

        # Make sure all parents are Gaussian vectors
        if any(len(node.dims[0]) != 1 for node in nodes):
            raise ValueError("Input nodes must be (Gaussian) vectors")

        self.slices = tuple(np.cumsum([0] + [node.dims[0][0] for node in nodes]))

        D = self.slices[-1]

        return super().__init__(*nodes, dims=((D,), (D, D)), **kwargs)

    def _compute_moments(self, *u_nodes):
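        # Build the moments [<z>, <z z^T>] of the concatenated vector z from
        # the parents' moments u_nodes[m] = [<x_m>, <x_m x_m^T>]: the mean is
        # the concatenation of the parent means, and the second moment starts
        # as a block-diagonal matrix of the parents' own second moments whose
        # off-diagonal cross-covariance blocks are filled in below from the
        # parent means.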
        x = misc.concatenate(*[u[0] for u in u_nodes], axis=-1)
        xx = misc.block_diag(*[u[1] for u in u_nodes])

        # Explicitly broadcast xx to plates of x
        x_plates = np.shape(x)[:-1]
        xx = np.ones(x_plates)[...,None,None] * xx

        # Compute the cross-covariance terms using the means of each variable
        # (because covariances are zero for factorized nodes in the VB
        # approximation)
        i_start = 0
        for m in range(len(u_nodes)):
            i_end = i_start + np.shape(u_nodes[m][0])[-1]
            j_start = 0
            for n in range(m):
                j_end = j_start + np.shape(u_nodes[n][0])[-1]
                xm_xn = linalg.outer(u_nodes[m][0], u_nodes[n][0], ndim=1)
                xx[...,i_start:i_end,j_start:j_end] = xm_xn
                xx[...,j_start:j_end,i_start:i_end] = misc.T(xm_xn)
                j_start = j_end
            i_start = i_end

        return [x, xx]

    def _compute_message_to_parent(self, i, m, *u_nodes):
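        # Message to parent i from the child's message m for the concatenated
        # vector: parent i gets its own segment of m[0] and the (i, i)
        # diagonal block of m[1], while the (i, j) off-diagonal blocks of
        # m[1] are absorbed into the linear term using the other parents'
        # means (the factor 2 accounts for both the (i, j) and (j, i) blocks
        # of the symmetric m[1]). self.slices holds the cumulative offsets,
        # e.g. parent dimensionalities (2, 3) give slices (0, 2, 5).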
        r = self.slices

        # Pick the proper parts from the message array
        m0 = m[0][...,r[i]:r[i+1]]
        m1 = m[1][...,r[i]:r[i+1],r[i]:r[i+1]]

        # Handle cross-covariance terms by using the mean of the covariate node
        for (j, u) in enumerate(u_nodes):
            if j != i:
                m0 = m0 + 2 * np.einsum(
                    '...ij,...j->...i',
                    m[1][...,r[i]:r[i+1],r[j]:r[j+1]],
                    u[0]
                )

        return [m0, m1]
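
# A minimal end-to-end sketch (illustrative, not part of this module),
# assuming the standard BayesPy API (bayespy.nodes.Gaussian,
# bayespy.inference.VB): the concatenated vector serves as the mean of an
# observed Gaussian, and updating the parents by variational message passing
# exercises _compute_moments and _compute_message_to_parent above.
#
#     from bayespy.nodes import Gaussian
#     from bayespy.inference import VB
#
#     x = Gaussian(np.zeros(2), np.identity(2))
#     y = Gaussian(np.zeros(3), np.identity(3))
#     z = ConcatGaussian(x, y)
#
#     obs = Gaussian(z, np.identity(5))
#     obs.observe(np.ones(5))
#
#     Q = VB(obs, x, y)
#     Q.update(repeat=10)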