{-# LANGUAGE Arrows #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

--- Imports ---

-- Goal --

import Pendulum

import Goal.Core
import Goal.Geometry
import Goal.Probability
import Goal.Simulation

-- Base --

import Text.Read (readMaybe)

-- Qualified --

import qualified System.Directory as D

--- Program ---

-- Globals --

-- | First argument to 'boundedStochasticVanillaGradientDescent'
-- (presumably the learning rate — confirm against Goal).
eps = 0.01
-- | Second argument to 'boundedStochasticVanillaGradientDescent'
-- (presumably a bound on the update size — confirm against Goal).
bnd = 2 :: Double
-- | Passed to 'beliefBackpropagation' as its final argument.
cdn = 5 :: Int
-- | Passed to 'beliefBackpropagation' before 'cdn'.
blkn = 10 :: Int
-- | Step-count argument given to 'generatePath' for each training path.
nstpstrn = 10 :: Int
-- | Number of trainer outputs consumed; the last one is the trained network.
-- Must be positive (it feeds 'last').
nepchs = 100000 :: Int
-- | Bin-count argument for the weight histogram ('coordinateLogHistogram').
nbns = 10

-- Functions --

-- | Build the training Mealy machine, starting from the parameters @nnp0@.
--
-- Each step generates a path of 'nstpstrn' steps from a 'randomState' using
-- the current parameters, feeds it to the bounded gradient-descent
-- backpropagator and emits the updated parameters. The updated parameters
-- are both the output and the next machine state.
trainerMWC nnp0 = do
    let randomPath nnp = randomState >>= generatePath nstpstrn nnp
    pathGenerator <- accumulateRandomFunction0 randomPath
    backpropagator <- boundedStochasticVanillaGradientDescent eps bnd (beliefBackpropagation trnss blkn cdn) nnp0
    return . accumulateMealy nnp0 $ proc ((),!nnp) -> do
        !nnp' <- backpropagator <<< pathGenerator -< nnp
        returnA -< (nnp',nnp')

-- Main --

-- | Train the network and plot the results.
--
-- Initial coordinates come from 'flnm' when that file exists; otherwise they
-- are sampled from @Normal [0,0.01]@. After 'nepchs' training steps, this
-- renders the true vs. estimated vector fields and a weight histogram. Once
-- the window returns, the trained coordinates are saved back to 'flnm'.
main :: IO ()
main = do
    bl <- D.doesFileExist flnm
    c0s <- if bl
        then loadCoordinates flnm
        else runWithSystemRandom . replicateM (dimension nn) . generate . chart Standard $ fromList Normal [0,0.01]
    let nnp0 = fromList nn c0s
    trainer <- runWithSystemRandom $ trainerMWC nnp0
    -- Alternative (disabled) streaming/printing driver:
    --(lns,lzs,nnps) <- unzip3 <$> streamM trainer printerIO (replicate nepchs ())
    -- The nepchs-th output of the trainer is the trained network.
    let nnp1 = last . take nepchs $ streamChain trainer
        (mp,mtx1,np,mtx2) = splitNeuralNetwork nnp1
    let wgtlyt = coordinateLogHistogram nbns "Network Weights" ["Second Layer Biases", "Second Layer", "First Layer Biases", "First Layer"] [coordinates mp, coordinates mtx1, coordinates np, coordinates mtx2]
    let qdq0 = fromList (Bundle pndl) [1.5,1.5]
        vflyt = execEC $ do
            layout_title .= "Vector Field"
            vectorFieldLayout
            plot . fmap plotVectorField . liftEC $ do
                vectorFieldPlot $ opaque black
                plot_vectors_title .= "Tru"
                plot_vectors_mapf .= sliceVectorField scl 0 f qdq0
            plot . fmap plotVectorField . liftEC $ do
                vectorFieldPlot $ opaque blue
                plot_vectors_mapf .= locationBeliefField trns0 dt scl 0 1 [0,0] nnp1
                plot_vectors_title .= "Est"
    let rnbl = toRenderable $ StackedLayouts [StackedLayout vflyt, StackedLayout wgtlyt] False
    void $ renderableToAspectWindow False 400 800 rnbl
    saveCoordinates flnm $ listCoordinates nnp1
  where
    -- | Read and parse coordinates saved by a previous run.
    -- Branching on the parse result consumes the whole file immediately, so
    -- the lazily-read handle is closed before the file is later replaced. A
    -- malformed file yields a descriptive IO error instead of the partial
    -- 'read' failure @Prelude.read: no parse@.
    loadCoordinates :: Read a => FilePath -> IO a
    loadCoordinates fl = do
        s <- readFile fl
        case readMaybe s of
            Just cs -> return cs
            Nothing -> ioError . userError $
                "Could not parse saved network coordinates in " ++ fl
                    ++ "; delete or repair the file to reinitialise the network."

    -- | Save coordinates by writing a temporary file and renaming it over
    -- the target. An interrupted write then never truncates the previously
    -- saved weights.
    saveCoordinates :: Show a => FilePath -> a -> IO ()
    saveCoordinates fl x = do
        let tmp = fl ++ ".tmp"
        writeFile tmp (show x)
        D.renameFile tmp fl