module Hopfield.Hopfield (
Pattern
, Weights
, LearningType (Hebbian, Storkey)
, HopfieldData ()
, weights
, patterns
, buildHopfieldData
, update
, addPatterns
, repeatedUpdate
, updateChain
, matchPattern
, computeH
, energy
) where
import Control.Monad
import Data.List (foldl')
import Data.Maybe

import Control.Monad.Random (MonadRandom)
import Data.Vector ((!))
import qualified Data.Vector as V
import Data.Vector.Generic.Mutable (write)

import Hopfield.Common
import Hopfield.Util
-- | Learning rule used to build the weight matrix
-- ('Hebbian' selects 'train', 'Storkey' selects 'trainStorkey').
data LearningType
  = Hebbian
  | Storkey
  deriving (Eq, Show, Read)
-- | A trained network: its weight matrix together with the list of patterns
-- that were used to build it.
data HopfieldData = HopfieldData
  { weights  :: Weights    -- ^ weight matrix used by the update rule
  , patterns :: [Pattern]  -- ^ patterns stored in the network
  } deriving (Show)
-- | Apply @f@ to the weights and pattern only after they pass validation.
--
-- The checks run in order: weights, pattern, then their relative sizes. The
-- first failing check aborts with 'error' and its message.
checkWsPat :: (Weights -> Pattern -> a) -> Weights -> Pattern -> a
checkWsPat f ws pat =
  maybe (f ws pat) error firstProblem
  where
    -- 'msum' on 'Maybe' yields the first 'Just', evaluating left to right.
    firstProblem =
      msum [validWeights ws, validPattern pat, validWeightsPatternSize ws pat]
-- | Validated single-neuron update step.
--
-- Checks the inputs with 'checkWsPat', then delegates to 'update_'. That
-- function returns 'Nothing' when no neuron needs to change.
update :: MonadRandom m => Weights -> Pattern -> m (Maybe Pattern)
update ws pat = checkWsPat update_ ws pat
-- | Validated repeated update.
--
-- Checks the inputs with 'checkWsPat', then delegates to 'repeatedUpdate_'.
-- That function keeps applying the single-step update via @repeatUntilNothing@
-- (presumably until a step yields 'Nothing').
repeatedUpdate :: (MonadRandom m) => Weights -> Pattern -> m Pattern
repeatedUpdate ws pat = checkWsPat repeatedUpdate_ ws pat
-- | Validated version of 'computeH_': the state neuron @i@ should take for the
-- given weights and current pattern.
computeH :: Weights -> Pattern -> Int -> Int
computeH ws pat i = checkWsPat atIndex ws pat
  where
    atIndex w p = computeH_ w p i
-- | Validated version of 'energy_': the network energy of a pattern.
energy :: Weights -> Pattern -> Double
energy ws pat = checkWsPat energy_ ws pat
-- | Train a network on the given patterns using the chosen learning rule.
--
-- Aborts with 'error' in three cases:
--
-- * the pattern list is empty;
-- * the first pattern has length 0;
-- * the patterns do not all have the same length.
buildHopfieldData :: LearningType -> [Pattern] -> HopfieldData
buildHopfieldData _ [] = error "Train patterns are empty"
buildHopfieldData learningType pats@(firstPat : _)
  | patLen == 0
      = error "Cannot have empty patterns"
  | not (all ((== patLen) . V.length) pats)
      = error "All training patterns must have the same length"
  | otherwise
      = HopfieldData (trainWith learningType pats) pats
  where
    patLen = V.length firstPat

    trainWith Hebbian = train
    trainWith Storkey = trainStorkey
-- | Hebbian learning.
--
-- Entry (i, j) is the sum over all patterns of @p_i * p_j@, scaled with
-- @./. n@, where @n@ is the pattern length (presumably Int-to-Double
-- division). The diagonal is zero.
--
-- Assumes a non-empty list of equally sized patterns, as guaranteed by
-- 'buildHopfieldData'.
train :: [Pattern] -> Weights
train pats = vector2D ws
  where
    n = V.length (head pats)
    ws = [ [ w i j ./. n | j <- [0 .. n - 1] ] | i <- [0 .. n - 1] ]
    w i j
      | i == j    = 0
      | otherwise = sum [ (pat ! i) * (pat ! j) | pat <- pats ]
-- | State neuron @i@ should take.
--
-- Returns @1@ if its weighted input @sum_j w_ij * x_j@ is non-negative, and
-- @-1@ otherwise. No validation is done.
--
-- Uses unchecked indexing, so row @i@ of the weights must be at least as long
-- as the pattern.
computeH_ :: Weights -> Pattern -> Int -> Int
computeH_ ws pat i = if weighted >= 0 then 1 else -1
  where
    wss = ws ! i
    p = V.length pat

    weighted :: Double
    weighted = go 0 0.0

    -- Strict accumulating loop. Entries > 0 count as +1 and all others as -1,
    -- so w_ij * x_j is just +w or -w.
    go :: Int -> Double -> Double
    go !j !s
      | j == p    = s
      | otherwise =
          let w = wss `V.unsafeIndex` j
              x = if pat `V.unsafeIndex` j > 0 then w else -w
          in go (j + 1) (s + x)
-- | One single-neuron update step, without input validation.
--
-- Visits the neurons in a random order and flips the first one whose current
-- state differs from the value 'computeH_' gives for it. Returns 'Nothing'
-- when every neuron already agrees with 'computeH_'.
update_ :: MonadRandom m => Weights -> Pattern -> m (Maybe Pattern)
update_ ws pat = do
  randomIndices <- shuffle . toArray $ [0 .. V.length pat - 1]
  return $ fmap (flipAtIndex pat) (firstUpdatable (V.fromList randomIndices))
  where
    -- First neuron, in the given visiting order, that would change state.
    firstUpdatable indices = go 0
      where
        go n
          | n == V.length pat = Nothing
          | pat ! i /= computeH_ ws pat i = Just i
          | otherwise = go (n + 1)
          where i = indices ! n

    -- Negate the entry at @index@ (1 <-> -1). 'seq' forces the bounds-checked
    -- read before the copy is made.
    flipAtIndex vec index =
      let val = vec ! index
      in val `seq` V.modify (\v -> write v index (-val)) vec
-- | Unvalidated repeated update: feeds 'update_' to @repeatUntilNothing@.
repeatedUpdate_ :: (MonadRandom m) => Weights -> Pattern -> m Pattern
repeatedUpdate_ ws = repeatUntilNothing (update_ ws)
-- | Run 'repeatedUpdate_' from @pat@, then look the resulting pattern up among
-- the stored patterns with @findInList@.
--
-- Neither @pat@ nor the weights are validated here.
--
-- NOTE(review): the meaning of the 'Either' comes from @findInList@
-- (Hopfield.Util). Presumably 'Right' is the index of the matching stored
-- pattern and 'Left' is the unmatched converged pattern; confirm.
matchPattern :: MonadRandom m => HopfieldData -> Pattern -> m (Either Pattern Int)
matchPattern (HopfieldData ws pats) pat =
  findInList pats `liftM` repeatedUpdate_ ws pat
-- | Patterns visited while repeatedly applying 'update_' from @pat@, starting
-- with @pat@ itself. The rest of the list comes from @unfoldrSelfM@.
--
-- Only the pattern is validated; an invalid one aborts with 'error'.
updateChain :: (MonadRandom m) => HopfieldData -> Pattern -> m [Pattern]
updateChain (HopfieldData ws _) pat =
  case validPattern pat of
    Just e  -> error e
    Nothing -> liftM (pat :) (unfoldrSelfM (update_ ws) pat)
-- | Add patterns to a trained network.
--
-- The weight matrix is updated incrementally, one pattern at a time, with the
-- given learning rule. The new patterns are appended to the stored ones.
--
-- Aborts with 'error' if any new pattern is invalid, or if its size does not
-- match the weight matrix.
addPatterns :: LearningType -> HopfieldData -> [Pattern] -> HopfieldData
addPatterns learning (HopfieldData ws pats) addedPats
  | any (isJust . validPattern) addedPats
      = error "invalid patterns in addPatterns"
  | any (isJust . validWeightsPatternSize ws) addedPats
      = error "pattern does not match weights in addPatterns"
  | otherwise
      = HopfieldData new_ws (pats ++ addedPats)
  where
    -- Strict left fold: force each intermediate matrix instead of building a
    -- chain of deferred updates.
    new_ws = foldl' (updateWeightsGivenNewPattern learning) ws addedPats
-- | Add one pattern to an existing weight matrix.
--
-- For 'Storkey' this delegates to 'updateWeightsStorkey'. For 'Hebbian' it
-- matches 'train':
--
-- * entry (i, j) gains @(p_i * p_j) ./. n@, where @n@ is the number of neurons;
-- * the diagonal stays zero.
--
-- So adding patterns one by one gives, up to floating-point rounding, the
-- same matrix as training on all of them at once.
updateWeightsGivenNewPattern :: LearningType -> Weights -> Pattern -> Weights
updateWeightsGivenNewPattern Storkey ws pat = updateWeightsStorkey ws pat
updateWeightsGivenNewPattern Hebbian ws pat = vector2D updated_ws
  where
    n = V.length ws
    neurons = [0 .. n - 1]
    updated_ws = [ [ w i j | j <- neurons ] | i <- neurons ]
    -- Keep the diagonal at zero, as 'train' does and 'validWeights' requires.
    w i j
      | i == j    = 0
      | otherwise = ws ! i ! j + ((pat ! i * pat ! j) ./. n)
-- | Network energy of a pattern, without validation:
-- @E = -1/2 * sum_{i,j} w_ij * x_i * x_j@.
--
-- Here @*.@ presumably multiplies a Double weight by an Int product.
energy_ :: Weights -> Pattern -> Double
energy_ ws pat = s / (-2.0)
  where
    p = V.length pat
    w i j = ws ! i ! j
    x i = pat ! i
    s = sum [ w i j *. (x i * x j) | i <- [0 .. p - 1], j <- [0 .. p - 1] ]
-- | Check that every entry of a pattern is @1@ or @-1@.
--
-- Returns 'Nothing' for a valid pattern. Otherwise returns an error message
-- naming the first offending value.
validPattern :: Pattern -> Maybe String
validPattern pat = case [ x | x <- V.toList pat, x /= 1, x /= -1 ] of
  []    -> Nothing
  x : _ -> Just $ "Pattern contains invalid value " ++ show x
-- | Check that the pattern has as many entries as the weight matrix has rows.
validWeightsPatternSize :: Weights -> Pattern -> Maybe String
validWeightsPatternSize ws pat =
  if V.length pat == V.length ws
    then Nothing
    else Just "Pattern size must match network size"
-- | Check that a weight matrix is usable.
--
-- The matrix must be non-empty, square, have a zero diagonal, and be
-- symmetric within an absolute tolerance of 1e-4. Returns 'Nothing' when every
-- check passes, otherwise the message of the first failing check.
validWeights :: Weights -> Maybe String
validWeights ws
  | n == 0
    = Just "Weight matrix must be non-empty"
  | any (\x -> V.length x /= n) $ V.toList ws
    = Just "Weight matrix has to be a square matrix"
  | any (/= 0) [ ws ! i ! i | i <- indices ]
    = Just "Weight matrix first diagonal must be zero"
  | not $ and [ abs ((ws ! i ! j) - (ws ! j ! i)) < 0.0001 | i <- indices, j <- indices ]
    = Just "Weight matrix must be symmetric"
  -- NOTE(review): this guard can never fire. For n > 0 the list has n*n
  -- elements, so 'null' is always False. A real range check (e.g. 'or' in
  -- place of 'null') would reject matrices that training can produce: Hebbian
  -- entries are sums of +-1 terms divided by n, so they can exceed 1 with
  -- many patterns. It is left inactive; confirm intent.
  | null [ abs (ws ! i ! j) > 1 | i <- indices, j <- indices ]
    = Just "Weights should be between (-1, 1)"
  | otherwise = Nothing
  where
    n = V.length ws
    indices = [0 .. n - 1]
-- | Storkey local field @h_ij = sum_{k /= i, k /= j} w_ik * p_k@.
storkeyHiddenSum :: Weights -> Pattern -> Int -> Int -> Double
storkeyHiddenSum ws pat i j
  = sum [ ws ! i ! k *. (pat ! k) | k <- [0 .. n - 1], k /= i, k /= j ]
  where n = V.length ws
-- | New value of weight (i, j) after learning @pat@ with the Storkey rule:
-- @w_ij + (1/n) * (p_i*p_j - p_i*h_ji - h_ij*p_j)@, where @h@ is
-- 'storkeyHiddenSum'. The diagonal is fixed at 0.
updateWeightsGivenIndicesStorkey :: Weights -> Pattern -> Int -> Int -> Double
updateWeightsGivenIndicesStorkey ws pat i j
  | i == j    = 0.0
  | otherwise = ws ! i ! j + ((1 :: Int) ./. n) * delta
  where
    n = V.length ws
    h = storkeyHiddenSum ws pat
    delta = fromIntegral (pat ! i * (pat ! j))
          - (h j i *. (pat ! i))
          - (h i j *. (pat ! j))
-- | Learn one pattern with the Storkey rule, recomputing every entry of the
-- matrix from the previous weights.
updateWeightsStorkey :: Weights -> Pattern -> Weights
updateWeightsStorkey ws pat
  = vector2D [ [ updateWeightsGivenIndicesStorkey ws pat i j | j <- neurons ] | i <- neurons ]
  where
    n = V.length ws
    neurons = [0 .. n - 1]
-- | Storkey learning.
--
-- Starts from an all-zero n x n matrix (n = pattern length) and learns the
-- patterns one at a time with 'updateWeightsStorkey'.
--
-- Assumes a non-empty list, as guaranteed by 'buildHopfieldData'.
trainStorkey :: [Pattern] -> Weights
trainStorkey pats = foldl' updateWeightsStorkey start_ws pats
  where
    n = V.length $ head pats
    start_ws = vector2D $ replicate n $ replicate n 0