Safe Haskell: None
Base Restricted Boltzmann machine.
- learningRate :: Double
- data Mode = Hidden | Visible | Classification
- data BoltzmannData = BoltzmannData {
    weightsB :: Weights,
    classificationWeights :: Weights,
    biasB :: Bias,
    biasC :: Bias,
    biasD :: Bias,
    patternsB :: [Pattern],
    hiddenCount :: Int,
    pattern_to_class :: [(Pattern, Int)] }
- getDimension :: Mode -> Weights -> Int
- buildCBoltzmannData :: MonadRandom m => [Pattern] -> m BoltzmannData
- buildCBoltzmannData' :: MonadRandom m => [Pattern] -> Int -> m BoltzmannData
- getActivationSum :: Weights -> Bias -> Pattern -> Int -> Double
- getActivationProbabilityVisible :: Weights -> Bias -> Pattern -> Int -> Double
- getActivationSumHidden :: Weights -> Weights -> Bias -> Pattern -> Pattern -> Int -> Double
- getHiddenSums :: Weights -> Weights -> Bias -> Pattern -> Pattern -> Vector Double
- getActivationProbabilityHidden :: Weights -> Weights -> Bias -> Pattern -> Pattern -> Int -> Double
- updateNeuronVisible :: MonadRandom m => Weights -> Bias -> Pattern -> Int -> m Int
- updateNeuronHidden :: MonadRandom m => Weights -> Weights -> Bias -> Pattern -> Pattern -> Int -> m Int
- updateVisible :: MonadRandom m => Weights -> Bias -> Pattern -> m Pattern
- updateHidden :: MonadRandom m => Weights -> Weights -> Bias -> Pattern -> Pattern -> m Pattern
- updateClassification :: Weights -> Bias -> Pattern -> Pattern
- getClassificationVector :: [(Pattern, Int)] -> Pattern -> Pattern
- oneTrainingStep :: MonadRandom m => BoltzmannData -> Pattern -> m BoltzmannData
- trainBoltzmann :: MonadRandom m => [Pattern] -> Int -> m BoltzmannData
- matchPatternCBoltzmann :: BoltzmannData -> Pattern -> Int
- getFreeEnergy :: BoltzmannData -> Pattern -> Pattern -> Double
- activation :: Double -> Double
- softplus :: Double -> Double
- validClassificationVector :: Pattern -> Int -> Maybe String
- validPattern :: Mode -> Weights -> Pattern -> Maybe String
- validWeights :: Weights -> Maybe String
Documentation
learningRate :: Double
Determines the rate at which the weights are changed during the training phase. See http://en.wikipedia.org/wiki/Restricted_Boltzmann_machine#Training_algorithm
data BoltzmannData
Constructors: BoltzmannData
Instances: Show BoltzmannData
getDimension :: Mode -> Weights -> Int
Retrieves the dimension of the weights matrix corresponding to the given mode. For Hidden it is the width of the matrix (number of columns); for Visible it is the height (number of rows). The caller must ensure that the appropriate weights matrix is passed to this function.
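The width/height convention can be illustrated with a minimal sketch. It assumes a list-of-rows stand-in for Weights (one row per visible neuron, one column per hidden neuron), which is not necessarily how the module represents matrices. getDimensionSketch is a hypothetical name, and the handling of Classification is a guess because the documentation only covers Hidden and Visible.

```haskell
-- Hypothetical stand-in for the module's Weights type (an assumption):
-- a list of rows, one row per visible neuron, one column per hidden neuron.
type Weights = [[Double]]

data Mode = Hidden | Visible | Classification
  deriving (Show, Eq)

-- | Sketch of getDimension: Hidden reads the width (number of columns),
-- Visible reads the height (number of rows). Treating Classification like
-- Visible is an assumption, since the documentation does not cover it.
getDimensionSketch :: Mode -> Weights -> Int
getDimensionSketch Hidden ws = case ws of
  []        -> 0
  (row : _) -> length row
getDimensionSketch _ ws = length ws
```

For a 2 x 3 matrix, getDimensionSketch Hidden returns 3 and getDimensionSketch Visible returns 2.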
buildCBoltzmannData :: MonadRandom m => [Pattern] -> m BoltzmannData
buildCBoltzmannData patterns trains a Boltzmann network with patterns. The number of hidden neurons is set to the number of visible neurons.
buildCBoltzmannData' :: MonadRandom m => [Pattern] -> Int -> m BoltzmannData
buildCBoltzmannData' patterns nrHidden takes a list of patterns and builds a Boltzmann network (by training) with nrHidden hidden neurons, in which these patterns are stable states. The result can be used to run a pattern against the network with matchPatternCBoltzmann.
getActivationSum :: Weights -> Bias -> Pattern -> Int -> Double
getActivationSum ws bias pat index computes the activation sum for neuron index. It is used to compute the activation probability of a neuron in the visible layer, and parts of the sums required for the probabilities of the classifications.
getActivationProbabilityVisible :: Weights -> Bias -> Pattern -> Int -> Double
getActivationProbabilityVisible ws bias h index returns the activation probability for neuron index in the visible layer, given the weights matrix ws, the vector of biases bias and the hidden pattern h. It applies the activation function to the activation sum to obtain the probability.
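The visible activation probability can be sketched as the logistic sigmoid of bias_i plus the weighted sum of the hidden states. This assumes list-based stand-ins for Weights, Bias and Pattern, binary 0/1 neuron states, and that ws !! i holds the weights from visible neuron i to every hidden neuron. visibleSum and visibleProbability are hypothetical names, not the module's API.

```haskell
-- Hypothetical list-based stand-ins for the module's types.
type Weights = [[Double]]  -- assumed layout: ws !! i !! j, i visible, j hidden
type Bias    = [Double]
type Pattern = [Int]       -- assumed binary 0/1 neuron states

-- | Logistic sigmoid, the activation function named in the documentation.
sigmoid :: Double -> Double
sigmoid x = 1 / (1 + exp (negate x))

-- | Activation sum for visible neuron i: bias_i + sum_j w_ij * h_j.
visibleSum :: Weights -> Bias -> Pattern -> Int -> Double
visibleSum ws bias h i =
  bias !! i + sum (zipWith (\w hj -> w * fromIntegral hj) (ws !! i) h)

-- | Probability that visible neuron i is on, given the hidden state h.
visibleProbability :: Weights -> Bias -> Pattern -> Int -> Double
visibleProbability ws bias h i = sigmoid (visibleSum ws bias h i)
```

A neuron whose activation sum is exactly 0 gets probability 0.5, since the sigmoid is centred at zero.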
getActivationSumHidden :: Weights -> Weights -> Bias -> Pattern -> Pattern -> Int -> Double
getActivationSumHidden ws u bias v classification_vector index returns the activation sum for neuron index in the hidden layer, given the weight matrices ws and u, the vector of biases bias, the visible pattern v and the classification vector classification_vector.
getHiddenSums :: Weights -> Weights -> Bias -> Pattern -> Pattern -> Vector Double
getHiddenSums ws u bias v classification_vector returns the activation sums for all neurons in the hidden layer, given the weight matrices ws and u, the vector of biases bias, the visible pattern v and the classification vector classification_vector.
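For a classification RBM, the hidden sums are commonly written as c_j + sum_i v_i w_ij + sum_k y_k u_kj, where y is the one-hot classification vector. The sketch below computes all of them at once under that assumption, with list-based stand-in types and assumed matrix layouts (ws: one row per visible neuron; u: one row per class; both with one column per hidden neuron). hiddenSums and dot are hypothetical names.

```haskell
import Data.List (transpose)

-- Hypothetical list-based stand-ins for the module's types.
type Weights = [[Double]]
type Bias    = [Double]
type Pattern = [Int]

-- | Dot product of a weight column with a 0/1 pattern.
dot :: [Double] -> Pattern -> Double
dot col xs = sum (zipWith (\w x -> w * fromIntegral x) col xs)

-- | Hidden activation sums for every hidden neuron j:
--   c_j + sum_i v_i * w_ij + sum_k y_k * u_kj
-- Transposing turns each matrix into one list per hidden neuron.
hiddenSums :: Weights -> Weights -> Bias -> Pattern -> Pattern -> [Double]
hiddenSums ws u c v y =
  zipWith3 (\cj wCol uCol -> cj + dot wCol v + dot uCol y)
           c (transpose ws) (transpose u)
```

With ws = [[1,2],[3,4]], u = [[10,20]], c = [0.5,-0.5], v = [1,1] and y = [1], the sums are [14.5, 25.5].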
getActivationProbabilityHidden :: Weights -> Weights -> Bias -> Pattern -> Pattern -> Int -> Double
getActivationProbabilityHidden ws u bias v classification_vector index returns the activation probability for neuron index in the hidden layer, given the weight matrices ws and u, the vector of biases bias, the visible pattern v and the classification vector classification_vector. It applies the activation function to the activation sum to obtain the probability.
updateNeuronVisible :: MonadRandom m => Weights -> Bias -> Pattern -> Int -> m Int
updateNeuronVisible ws bias h index updates a neuron in the visible layer by using gibbsSampling, according to the activation probability.
updateNeuronHidden :: MonadRandom m => Weights -> Weights -> Bias -> Pattern -> Pattern -> Int -> m Int
Updates a neuron in the hidden layer by using gibbsSampling, according to the activation probability.
updateVisible :: MonadRandom m => Weights -> Bias -> Pattern -> m Pattern
Updates the entire visible layer by using gibbsSampling, according to the activation probability.
updateHidden :: MonadRandom m => Weights -> Weights -> Bias -> Pattern -> Pattern -> m Pattern
Updates the entire hidden layer by using gibbsSampling, according to the activation probability.
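Gibbs sampling of a single neuron and of a whole layer can be sketched without MonadRandom by passing the uniform random draws in explicitly. This is an illustrative simplification, not the module's gibbsSampling; sampleNeuron and sampleLayer are hypothetical names, and neurons are assumed to be binary 0/1.

```haskell
-- | One Gibbs-sampling step for a binary neuron. Instead of MonadRandom,
-- the caller supplies a uniform draw r in [0, 1) (an assumption of this
-- sketch); the neuron turns on exactly when r < p, which happens with
-- probability p.
sampleNeuron :: Double -> Double -> Int
sampleNeuron p r
  | r < p     = 1
  | otherwise = 0

-- | Whole-layer update: one independent draw per neuron. This is valid in an
-- RBM because neurons in one layer are conditionally independent given the
-- other layer, so they can all be sampled at once.
sampleLayer :: [Double] -> [Double] -> [Int]
sampleLayer = zipWith sampleNeuron
```

Passing the draws in keeps the sampling rule pure and easy to test; a monadic version would generate one draw per neuron and then apply the same comparison.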
updateClassification :: Weights -> Bias -> Pattern -> Pattern
Updates a classification vector given the current state of the network (the u matrix and the vector of biases d, together with a hidden vector h).
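One plausible reading of updateClassification, given that it is pure and takes u, d and h, is a deterministic argmax over the class scores d_k + sum_j u_kj h_j, returned as a one-hot vector. The sketch below follows that assumption with list-based stand-in types; classScores and classificationSketch are hypothetical names.

```haskell
-- Hypothetical list-based stand-ins for the module's types.
type Weights = [[Double]]  -- assumed: u has one row per class, one column per hidden neuron
type Bias    = [Double]
type Pattern = [Int]

-- | Score of each class k: d_k + sum_j u_kj * h_j.
classScores :: Weights -> Bias -> Pattern -> [Double]
classScores u d h =
  zipWith (\dk row -> dk + sum (zipWith (\w hj -> w * fromIntegral hj) row h)) d u

-- | Hypothetical sketch of updateClassification: return a one-hot vector for
-- the highest-scoring class. Taking the argmax is an assumption; because the
-- softmax is monotonic, it selects the class a softmax would rank first.
classificationSketch :: Weights -> Bias -> Pattern -> Pattern
classificationSketch u d h =
  [if k == best then 1 else 0 | k <- [0 .. length scores - 1]]
  where
    scores = classScores u d h
    best   = snd (maximum (zip scores [0 :: Int ..]))
```

With u = [[1,0],[0,2],[0.5,0.5]], d = [0,0,0] and h = [1,1] the scores are [1,2,1], so classificationSketch returns [0,1,0].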
getClassificationVector :: [(Pattern, Int)] -> Pattern -> Pattern
oneTrainingStep :: MonadRandom m => BoltzmannData -> Pattern -> m BoltzmannData
One step which updates the weights in the CD-n training process. The weights are changed according to one of the training patterns. See http://en.wikipedia.org/wiki/Restricted_Boltzmann_machine#Training_algorithm
oneTrainingStep bm visible updates the parameters of bm (the two weight matrices and the biases) according to the training instance visible and its classification, which is obtained by looking it up in the map kept in bm.
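The weight update performed in each CD-1 step can be sketched with the standard rule from the linked Wikipedia section, w_ij <- w_ij + eta * (v0_i h0_j - v1_i h1_j), where (v0, h0) comes from the training pattern and (v1, h1) from one Gibbs reconstruction. The sketch passes eta explicitly in place of learningRate and uses list-based stand-in types. cd1Weights and cd1Bias are hypothetical names, and applying the same rule to the classification weights (with the class vector in place of v) is an assumption.

```haskell
-- Hypothetical list-based stand-ins for the module's types.
type Weights = [[Double]]  -- assumed: one row per visible neuron
type Pattern = [Int]

-- | CD-1 update of the visible-hidden weights:
--   w_ij <- w_ij + eta * (v0_i * h0_j - v1_i * h1_j)
-- (v0, h0) is the data phase, (v1, h1) the reconstruction after one Gibbs step.
cd1Weights :: Double -> Pattern -> Pattern -> Pattern -> Pattern -> Weights -> Weights
cd1Weights eta v0 h0 v1 h1 =
  zipWith3 updateRow (map fromIntegral v0) (map fromIntegral v1)
  where
    -- Update one row (one visible neuron) against every hidden neuron.
    updateRow :: Double -> Double -> [Double] -> [Double]
    updateRow v0i v1i row =
      zipWith3 (\w h0j h1j -> w + eta * (v0i * h0j - v1i * h1j))
               row (map fromIntegral h0) (map fromIntegral h1)

-- | Matching bias update for either layer: b <- b + eta * (x0 - x1).
cd1Bias :: Double -> Pattern -> Pattern -> [Double] -> [Double]
cd1Bias eta x0 x1 bias =
  zipWith3 (\b a c -> b + eta * fromIntegral (a - c)) bias x0 x1
```

The positive term pulls the weights toward the training pattern, and the negative term pushes them away from what the model reconstructs; when the two agree, the update is zero.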
trainBoltzmann :: MonadRandom m => [Pattern] -> Int -> m BoltzmannData
The training function for the Boltzmann machine. It uses the contrastive divergence algorithm CD-1.
TODO see if making the vis
(It could be extended to CD-n, but in practice CD-1 has been shown to work surprisingly well.)
trainBoltzmann pats nrHidden, where pats are the training patterns and nrHidden is the number of neurons to be created in the hidden layer.
See http://en.wikipedia.org/wiki/Restricted_Boltzmann_machine#Training_algorithm
matchPatternCBoltzmann :: BoltzmannData -> Pattern -> Int
matchPatternCBoltzmann bm pat, given the trained Boltzmann network bm, recognizes pat by classifying it as one of the patterns the network was trained with. This is done by computing the free energy of pat with every possible classification and choosing the classification with the lowest energy.
See http://uai.sis.pitt.edu/papers/11/p463-louradour.pdf
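The classification rule (try every classification, keep the one with the lowest free energy) can be sketched independently of how the free energy is computed, by passing the free-energy function in as a parameter. oneHot and classifyByFreeEnergy are hypothetical names; representing classifications as one-hot vectors is an assumption consistent with the classification_vector arguments.

```haskell
-- | One-hot classification vector for class k out of n classes.
oneHot :: Int -> Int -> [Int]
oneHot n k = [if i == k then 1 else 0 | i <- [0 .. n - 1]]

-- | Sketch of the matching rule: evaluate the free energy of the visible
-- pattern with every one-hot classification vector and return the index of
-- the class with the lowest energy. The free-energy function stands in for
-- getFreeEnergy applied to a trained network.
classifyByFreeEnergy :: ([Int] -> [Int] -> Double) -> Int -> [Int] -> Int
classifyByFreeEnergy freeEnergy nClasses v =
  snd (minimum [(freeEnergy v (oneHot nClasses k), k) | k <- [0 .. nClasses - 1]])
```

Ties are broken toward the lower class index, because minimum compares the (energy, index) pairs lexicographically.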
getFreeEnergy :: BoltzmannData -> Pattern -> Pattern -> Double
getFreeEnergy bm visible classification_vector computes the free energy of visible with classification_vector, according to the trained Boltzmann network bm. It is used for classifying a given visible vector according to the classes used for training the network bm.
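A sketch of the free energy, assuming the standard form for a classification RBM as in the Louradour and Larochelle paper cited above: F(v, y) = -b.v - d.y - sum_j softplus(c_j + sum_i v_i w_ij + sum_k y_k u_kj). Whether the module includes the -b.v term is not stated; it does not depend on y, so it does not change which class has the lowest energy. The types, matrix layouts (w: one row per visible neuron; u: one row per class) and the names freeEnergy, dot and softplus here are stand-ins, not the module's definitions.

```haskell
import Data.List (transpose)

-- Hypothetical list-based stand-ins for the module's types.
type Weights = [[Double]]
type Bias    = [Double]
type Pattern = [Int]

-- | softplus x = log (1 + exp x), arranged so large x does not overflow exp.
softplus :: Double -> Double
softplus x
  | x > 0     = x + log (1 + exp (negate x))
  | otherwise = log (1 + exp x)

-- | Dot product of real weights with a 0/1 pattern.
dot :: [Double] -> Pattern -> Double
dot ws xs = sum (zipWith (\w x -> w * fromIntegral x) ws xs)

-- | F(v, y) = - b.v - d.y - sum_j softplus (c_j + sum_i v_i w_ij + sum_k y_k u_kj)
freeEnergy :: Weights -> Weights -> Bias -> Bias -> Bias -> Pattern -> Pattern -> Double
freeEnergy w u b c d v y =
  negate (dot b v) - dot d y
    - sum (zipWith3 (\cj wCol uCol -> softplus (cj + dot wCol v + dot uCol y))
                    c (transpose w) (transpose u))
```

With all parameters zero and one hidden neuron, the free energy is -softplus 0 = -log 2.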
activation :: Double -> Double
The activation function for the network (the logistic sigmoid). See http://en.wikipedia.org/wiki/Sigmoid_function
softplus :: Double -> Double
The function used to compute the free energy. See http://uai.sis.pitt.edu/papers/11/p463-louradour.pdf
validClassificationVector :: Pattern -> Int -> Maybe String
validPattern :: Mode -> Weights -> Pattern -> Maybe String
validPattern mode weights pattern returns an error string in a Just if pattern is not compatible with weights, and Nothing otherwise. mode gives the type of the pattern being checked (Visible or Hidden).
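A hypothetical sketch of the compatibility check, assuming "compatible" means that the pattern length equals the dimension getDimension would report for the given mode. The list-based types and the names expectedLength and validPatternSketch are assumptions; the real function may check more.

```haskell
-- Hypothetical list-based stand-ins for the module's types.
type Weights = [[Double]]  -- assumed: one row per visible neuron
type Pattern = [Int]

data Mode = Hidden | Visible | Classification
  deriving (Show, Eq)

-- | Expected pattern length for a mode (same convention as getDimension:
-- Hidden reads the width; other modes read the height, an assumption).
expectedLength :: Mode -> Weights -> Int
expectedLength Hidden ws = case ws of
  []      -> 0
  (r : _) -> length r
expectedLength _ ws = length ws

-- | Just an error message when the pattern length does not match the
-- dimension selected by mode, Nothing otherwise.
validPatternSketch :: Mode -> Weights -> Pattern -> Maybe String
validPatternSketch mode ws pat
  | n /= expected = Just ("pattern has length " ++ show n ++ ", expected " ++ show expected)
  | otherwise     = Nothing
  where
    n        = length pat
    expected = expectedLength mode ws
```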
validWeights :: Weights -> Maybe String
validWeights ws checks that a weight matrix is well formed.
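One possible notion of "well formed", sketched under the assumption that a weight matrix must be non-empty and rectangular (all rows non-empty and of equal length). validWeightsSketch is a hypothetical name, and the exact checks performed by the module are not documented.

```haskell
-- Hypothetical list-based stand-in for the module's Weights type.
type Weights = [[Double]]

-- | Hypothetical sketch of validWeights: Just an error message for an empty
-- or ragged matrix, Nothing when every row is non-empty and equally long.
validWeightsSketch :: Weights -> Maybe String
validWeightsSketch [] = Just "weight matrix has no rows"
validWeightsSketch (r : rs)
  | null r                          = Just "weight matrix has an empty row"
  | any ((/= length r) . length) rs = Just "weight matrix rows differ in length"
  | otherwise                       = Nothing
```

Comparing every row against the first one is enough: if all rows match the first, they all match each other.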