# Source code for bayespy.inference.vmp.nodes.ml

################################################################################
# Copyright (C) 2016 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################

import numpy as np

from .node import Moments
from .deterministic import Deterministic
from .stochastic import Stochastic


class DeltaMoments(Moments):
    r"""
    Class for the moments of constants or delta distributed variables

    The moments of such a variable consist of the value itself: see
    ``compute_fixed_moments``, which returns ``[x]`` unchanged.
    """


    def __init__(self, shape):
        # ``shape`` is stored as-is and also as the single entry of ``dims``.
        self.shape = shape
        self.dims = (shape,)
        return super().__init__()


    @classmethod
    def from_values(cls, x, ndim):
        r"""
        Construct moments whose shape is the trailing ``ndim`` axes of ``x``.

        Raises
        ------
        ValueError
            If ``x`` has fewer than ``ndim`` dimensions (broadcasting is not
            supported).
        """
        if np.ndim(x) < ndim:
            raise ValueError("Broadcasting not (yet) supported in DeltaMoments")
        # Special case needed because ``shape[-0:]`` is the *full* shape,
        # not the empty tuple.
        if ndim == 0:
            return cls(())
        else:
            return cls(np.shape(x)[-ndim:])


    def get_converter(self, moments_to):
        r"""
        Return an identity function if ``DeltaMoments`` is a subclass of
        ``moments_to``; otherwise return the node class produced by
        ``get_delta_moments_class_converter(moments_to)``.
        """
        if issubclass(DeltaMoments, moments_to):
            return lambda x: x
        return get_delta_moments_class_converter(moments_to)


    def compute_fixed_moments(self, x):
        r"""
        Compute the moments for a fixed value
        """
        return [x]


    def compute_dims_from_values(self, x):
        r"""
        Return the shape of the moments for a fixed value.

        Always returns ``((),)``, independent of ``x``.
        """
        # NOTE(review): this differs from ``self.dims == (shape,)`` set in
        # ``__init__``; confirm which one callers rely on.
        return ((),)


    def get_instance_conversion_kwargs(self):
        r"""
        Return the keyword arguments describing this instance, namely its
        ``shape``.
        """
        return dict(shape=self.shape)


    def get_instance_converter(self, shape):
        r"""
        Return ``None`` (no conversion needed) when ``shape`` equals this
        instance's shape.

        Raises
        ------
        ValueError
            If the shapes differ.
        """
        if shape != self.shape:
            raise ValueError(
                "DeltaMoments shape mismatch: {0} != {1}".format(
                    shape, self.shape
                )
            )
        return None


class DeltaClassConverterMoments(Moments):
    r"""
    Moments used by the node class returned from
    ``get_delta_moments_class_converter``.

    Stores the source node ``x`` and the target ``moments_class``. The
    instance-level converter is built only when ``get_instance_converter``
    is called.
    """


    def __init__(self, x, moments_class):
        # x: source node; its first moment (value) is read on conversion.
        self.x = x
        self.moments_class = moments_class
        # Initialize the base class, as the sibling DeltaMoments does.
        return super().__init__()


    def get_instance_conversion_kwargs(self):
        r"""
        Return a marker telling ``get_instance_converter`` that the source
        is already this kind of moments.
        """
        return dict(i_am_delta=True)


    def get_instance_converter(self, **kwargs):
        r"""
        Return ``None`` if ``kwargs`` carries the ``i_am_delta`` marker.

        Otherwise, build target moments with
        ``moments_class.from_values(value, **kwargs)``, where ``value`` is
        the first moment of ``self.x``, and wrap them in a
        ``DeltaInstanceConverter``.
        """
        if kwargs.get('i_am_delta'):
            return None
        moments = self.moments_class.from_values(
            self.x.get_moments()[0],
            **kwargs
        )
        return DeltaInstanceConverter(moments)


def get_delta_moments_class_converter(moments_class):
    r"""
    Build a node class that passes a delta-moments node through unchanged,
    while exposing ``DeltaClassConverterMoments`` for ``moments_class``.

    Parameters
    ----------
    moments_class
        Target moments class, stored in each instance's
        ``DeltaClassConverterMoments``.

    Returns
    -------
    type
        A ``Deterministic`` subclass named ``DeltaClassConverter``.
    """
    target_class = moments_class


    class DeltaClassConverter(Deterministic):


        def __init__(self, node):
            # Both attributes are assigned before delegating to the base
            # initializer; keep that order (the base presumably reads them).
            self._moments = DeltaClassConverterMoments(node, target_class)
            self._parent_moments = (node._moments,)
            return super().__init__(node, dims=((),))


        def _compute_moments(self, u):
            # Identity: the parent's moments are forwarded as-is.
            return u


        def _compute_message_to_parent(self, index, m, u):
            # Identity: messages from children are forwarded as-is.
            return m


    return DeltaClassConverter


class DeltaInstanceConverter:
    r"""
    Instance-level converter that turns a delta (point) value into the
    moments defined by a wrapped ``moments`` object.

    Every method delegates to ``self.moments`` or returns a constant.
    """


    def __init__(self, moments):
        # Target moments object that all computations delegate to.
        self.moments = moments


    def compute_moments(self, u):
        r"""Return the target moments of the value ``u[0]``."""
        value = u[0]
        return self.moments.compute_fixed_moments(value)


    def compute_message_to_parent(self, m, u_parent):
        r"""
        Return ``[gradient]``, where ``gradient`` is the second item of
        ``compute_fixed_moments(u_parent[0], gradient=m)``.
        """
        _, grad = self.moments.compute_fixed_moments(u_parent[0], gradient=m)
        return [grad]


    def compute_weights_to_parent(self, weights):
        r"""Return the constant weight ``1``; ``weights`` is ignored."""
        return 1


    def plates_multiplier_from_parent(self, plates_multiplier):
        r"""Return ``()``; the argument is ignored."""
        return ()


    def plates_from_parent(self, plates):
        r"""Delegate to ``moments.plates_from_shape``."""
        converted = self.moments.plates_from_shape(plates)
        return converted


    def plates_to_parent(self, plates):
        r"""Delegate to ``moments.shape_from_plates``."""
        converted = self.moments.shape_from_plates(plates)
        return converted


class MaximumLikelihood(Stochastic):
    r"""
    Node holding a single array value ``self._x`` with delta (point-mass)
    moments.

    Parameters
    ----------
    array : array-like
        The value. Its shape is used both as the ``DeltaMoments`` shape and
        as the node's plates.
    regularization : callable, optional
        If given, ``-sum(regularization(x))`` is this node's lower bound
        contribution; otherwise the contribution is 0.
    **kwargs
        Passed on to ``Stochastic.__init__``.
    """


    # This node has no parents.
    _parent_moments = ()


    def __init__(self, array, regularization=None, **kwargs):
        self._x = array
        self._moments = DeltaMoments(np.shape(array))
        self._regularization = regularization
        return super().__init__(
            plates=np.shape(array),
            dims=( (), ),
            initialize=False,
            **kwargs
        )


    def _get_id_list(self):
        # No random-variable ids are associated with this node.
        return []


    def get_moments(self):
        r"""Return ``[x]``: the moments of a delta variable are its value."""
        return [self._x]


    def lower_bound_contribution(self, ignore_masked=None):
        r"""
        Return ``-sum(regularization(x))``, or ``0`` without regularization.

        ``ignore_masked`` is accepted for interface compatibility and unused.
        """
        if self._regularization is None:
            return 0

        return -np.sum(self._regularization(self._x))


    def get_riemannian_gradient(self):
        r"""
        Return the messages from children, evaluated at the current value.
        """
        m_children = self._message_from_children(u_self=self.get_moments())
        g = m_children
        # TODO/FIXME: REGULARIZATION GRADIENT!! The regularization term used
        # in lower_bound_contribution is not included in this gradient.
        return g


    def get_gradient(self, rg):
        r"""Return the Riemannian gradient ``rg`` unchanged."""
        return rg


    def get_parameters(self):
        r"""Return the parameter list ``[x]``."""
        return [self._x]


    def set_parameters(self, x):
        r"""
        Set the value from a one-element parameter list ``[x]``.

        Raises
        ------
        ValueError
            If ``x`` does not contain exactly one element. (``ValueError``
            subclasses ``Exception``, so existing handlers still catch it.)
        """
        if len(x) != 1:
            raise ValueError("Wrong number of parameters. Should be 1, is {0}".format(len(x)))
        self._x = x[0]
        return


    def _update_distribution_and_lowerbound(self, m):
        # Not supported for this node.
        raise NotImplementedError()


class Function(Deterministic):
    r"""
    Deterministic node that applies a user-supplied function to the values
    of its parent nodes.

    Parameters
    ----------
    function : callable
        Called as ``function(*x)``, where ``x`` are the parents' first
        moments (their values).
    *nodes_gradients : (node, gradient) pairs
        Each parent node together with a callable ``gradient(m, *x)``.
        The callable maps the first element of the incoming message ``m``
        and the parent values ``x`` to the message for that parent.
    shape : tuple, optional
        Shape of the output. If None, ``function`` is evaluated once on the
        parents' current moments to determine it.
    **kwargs
        Passed on to ``Deterministic.__init__``.

    Raises
    ------
    ValueError
        If ``nodes_gradients`` is empty or not a sequence of pairs.
    """


    def __init__(self, function, *nodes_gradients, shape=None, **kwargs):
        self._function = function
        try:
            (nodes, gradients) = zip(*nodes_gradients)
        except ValueError as err:
            # Previously surfaced as an opaque tuple-unpacking error.
            raise ValueError(
                "Function requires one or more (node, gradient) pairs"
            ) from err
        self._parent_moments = tuple(node._moments for node in nodes)
        self._gradients = gradients
        if shape is None:
            # Shape wasn't given explicitly. Computes the output value once to
            # determine the shape.
            y = self._compute_moments(
                *[node.get_moments() for node in nodes]
            )
            shape = np.shape(y[0])
        self._moments = DeltaMoments(shape)
        return super().__init__(*nodes, dims=((),), **kwargs)


    def _compute_moments(self, *u_nodes):
        # Apply the function to each parent's first moment (its value).
        x = [u[0] for u in u_nodes]
        return [self._function(*x)]


    def _compute_message_to_parent(self, index, m, *u_nodes):
        # Use the gradient callable registered for parent ``index``.
        x = [u[0] for u in u_nodes]
        return [self._gradients[index](m[0], *x)]


    def _compute_weights_to_parent(self, index, mask):
        return 1


    def _compute_plates_from_parent(self, index, plates):
        # The output's plates are the output shape fixed in __init__.
        return self._moments.shape


    def _compute_plates_to_parent(self, index, plates):
        # Each parent keeps its own plates.
        return self.parents[index].plates