{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- Fits a univariate normal model to samples by gradient ascent on the
-- log-likelihood and plots the result. The figure shows log-likelihood
-- contours over (μ, σ²), the paths taken by several gradient-ascent schemes
-- from a common starting point, the true distribution, and the closed-form
-- maximum-likelihood estimate. It is written to "cross-entropy-descent.pdf".

--- Imports ---

-- Goal --

import Goal.Core
import Goal.Geometry
import Goal.Probability

--- Globals ---

-- | Number of samples drawn from the true distribution 'sp'.
nsmps :: Int
nsmps = 20

-- True Normal --

-- | The true distribution, given by its standard coordinates [1.5, 2]
-- (plotted on the μ / σ² axes, so presumably mean 1.5 and variance 2).
sp :: Standard :#: Normal
sp = chart Standard $ fromList Normal [1.5,2]

-- Gradient Ascent --

-- | Step size shared by every gradient-ascent run.
eps = 0.01

-- | Number of iterates kept from each gradient-ascent run.
stps :: Int
stps = 3000

-- | Common starting point of the ascent runs, in standard coordinates.
sp0 :: Standard :#: Normal
sp0 = chart Standard $ fromList Normal [0,1]

-- Plot --

-- Plot window: the x axis (μ) spans [mnmu, mxmu], the y axis (σ²) spans
-- [mnvr, mxvr].
mnmu = 0
mxmu = 3
mnvr = 1
mxvr = 4

-- | Linear axis parameters shared by both axes; tick labels are rounded to
-- whole numbers before being shown.
axprms = LinearAxisParams (show . round) 4 4

-- Contour sampling ranges for the two axes: (lower, upper, 600).
-- NOTE(review): the third component is presumably the grid resolution passed
-- to 'contours' — confirm against Goal.
m1rng = (mnmu,mxmu,600)
m2rng = (mnvr,mxvr,600)

-- | Number of contour levels drawn.
niso = 20

-- | One colour per contour level, ramping from opaque black (0,0,0,1) to
-- opaque red (1,0,0,1).
clrs = rgbaGradient (0,0,0,1) (1,0,0,1) niso

-- Functions --

-- | Log-likelihood of the samples @xs@ under @p@: the sum of the log-densities.
logLikelihood p xs = sum $ log . density p <$> xs

-- | Gradient of the average log-likelihood of @xs@ with respect to the natural
-- coordinates of @p@. It is computed as the sample mean of the sufficient
-- statistics minus @potentialMapping p@ (presumably the mean coordinates of
-- @p@ — confirm against Goal), reinterpreted as a tangent vector at @p@.
naturalDerivatives
  :: [Double]
  -> Natural :#: Normal
  -> Differentials :#: Tangent Natural Normal
naturalDerivatives xs p =
  fromCoordinates (Tangent p) . coordinates
    $ meanPoint (sufficientStatistic Normal <$> xs) <-> potentialMapping p

-- | Gradient of the average log-likelihood of @xs@ with respect to the
-- standard coordinates (μ, σ²) of @p@.
--
-- For a single sample x the partial derivatives are
--
--   ∂/∂μ  = (x - μ) / σ²
--   ∂/∂σ² = ((x - μ)² / σ² - 1) / (2σ²)
--
-- and these per-sample tangent vectors are averaged with 'meanPoint'.
standardDerivatives
  :: [Double]
  -> Standard :#: Normal
  -> Differentials :#: Tangent Standard Normal
standardDerivatives xs p =
  case listCoordinates p of
    [mu, vr] ->
      meanPoint
        [ fromList (Tangent p)
            [ recip vr * (xi - mu)
            , recip (2*vr) * (recip vr * (xi - mu)^2 - 1)
            ]
        | xi <- xs
        ]
    cs ->
      -- A normal in standard coordinates should always have exactly two
      -- coordinates. Report the mismatch explicitly instead of failing with
      -- an opaque irrefutable-pattern error.
      error $ "standardDerivatives: expected 2 standard coordinates, got "
        ++ show (length cs)

-- Layout --

-- | Draws 'nsmps' samples from 'sp', runs the gradient-ascent variants from
-- 'sp0', and renders the comparison plot to a PDF.
main :: IO ()
main = do

    smps <- runWithSystemRandom . replicateM nsmps $ generate sp

    -- Closed-form estimate: average the sufficient statistics (a point in
    -- Mixture coordinates, presumably the mean/expectation coordinates), then
    -- convert to standard coordinates for plotting.
    let mp' = chart Mixture . meanPoint $ sufficientStatistic Normal <$> smps
        sp' = chart Standard $ transition mp'

    -- Ascent in standard coordinates: vanillaGradientAscent versus
    -- gradientAscent (presumably the Riemannian/natural-gradient variant —
    -- confirm against Goal.Geometry).
    let vsps1 = take stps $ vanillaGradientAscent eps (standardDerivatives smps) sp0
        nsps1 = take stps $ gradientAscent eps (standardDerivatives smps) sp0

    -- Vanilla ascent in natural coordinates, mapped back to standard
    -- coordinates so it can share the (μ, σ²) axes.
    let np0 = chart Natural $ transition sp0
        vnps2 = take stps $ vanillaGradientAscent eps (naturalDerivatives smps) np0
        --nnps2 = take stps $ gradientAscent eps (naturalDerivatives smps) np0
        vsps2 = chart Standard . transition <$> vnps2
        --nsps2 = chart Standard . transition <$> nnps2

    let rnbl = toRenderable . execEC $ do

            -- Log-likelihood of the samples as a function of (μ, σ²).
            let f x y = logLikelihood (chart Standard $ fromList Normal [x,y]) smps
                cntrs = contours m1rng m2rng niso f

            layout_x_axis . laxis_generate .= scaledAxis axprms (mnmu,mxmu)
            layout_x_axis . laxis_override .= axisGridHide
            layout_x_axis . laxis_title .= "μ"
            layout_y_axis . laxis_generate .= scaledAxis axprms (mnvr,mxvr)
            layout_y_axis . laxis_override .= axisGridHide
            layout_y_axis . laxis_title .= "σ^2"

            -- One line plot per contour level, paired with its colour from
            -- 'clrs'. The first component of each contour (presumably its
            -- level value) is unused.
            sequence_ $ do
                ((_,cntr),clr) <- zip cntrs clrs
                return . plot . liftEC $ do
                    plot_lines_style .= solidLine 3 clr
                    plot_lines_values .= cntr

            -- Vanilla ascent in natural coordinates (blue).
            plot . liftEC $ do
                plot_lines_style .= solidLine 3 (opaque blue)
                plot_lines_values .= [toPair <$> vsps2]

            -- Vanilla ascent in standard coordinates (green).
            plot . liftEC $ do
                plot_lines_style .= solidLine 3 (opaque green)
                plot_lines_values .= [toPair <$> vsps1]

            -- gradientAscent in standard coordinates (purple).
            plot . liftEC $ do
                plot_lines_style .= solidLine 3 (opaque purple)
                plot_lines_values .= [toPair <$> nsps1]

            -- True distribution (black point).
            plot . liftEC $ do
                plot_points_style .= filledCircles 4 (opaque black)
                plot_points_values .= [toPair sp]

            -- Closed-form maximum-likelihood estimate (red point).
            plot . liftEC $ do
                plot_points_style .= filledCircles 4 (opaque red)
                plot_points_values .= [toPair sp']

    -- For an interactive preview instead of a PDF, use:
    --renderableToAspectWindow False 800 600 rnbl
    void $ renderableToFile (FileOptions (500,350) PDF) "cross-entropy-descent.pdf" rnbl