{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE BangPatterns #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE StandaloneDeriving #-} -- | Used for Multivariate distributions module HLearn.Models.Distributions.Multivariate.Interface ( Multivariate -- * Type functions -- , Ignore , MultiCategorical (..) , Independent (..) , Dependent (..) -- * Modules , Index (..) , module HLearn.Models.Distributions.Multivariate.Internal.Ignore , module HLearn.Models.Distributions.Multivariate.Internal.Marginalization ) where import Control.DeepSeq import GHC.TypeLits import HLearn.Algebra hiding (Index) import HLearn.Models.Distributions.Common import HLearn.Models.Distributions.Multivariate.Internal.CatContainer hiding (ds,baseparams) import HLearn.Models.Distributions.Multivariate.Internal.Container import HLearn.Models.Distributions.Multivariate.Internal.Ignore import HLearn.Models.Distributions.Multivariate.Internal.Marginalization import HLearn.Models.Distributions.Multivariate.Internal.Unital import HLearn.Models.Distributions.Multivariate.Internal.TypeLens import HLearn.Models.Distributions.Multivariate.MultiNormal ------------------------------------------------------------------------------- -- Multivariate -- | this is the main type for specifying multivariate distributions newtype Multivariate (dp:: *) (xs :: [[* -> * -> *]]) prob = Multivariate (MultivariateTF (Concat xs) prob) type family MultivariateTF (xs::[* -> * -> *]) prob type instance MultivariateTF '[] prob = Unital prob type instance MultivariateTF ((Container univariate sample) ': xs) prob = Container univariate sample (MultivariateTF xs prob) prob type instance MultivariateTF ((MultiContainer dist sample) ': xs) prob = MultiContainer dist sample (MultivariateTF xs prob) prob type instance MultivariateTF ((CatContainer label) ': xs) prob = CatContainer label (MultivariateTF xs prob) prob type instance MultivariateTF ((Ignore' label) ': xs) prob = Ignore' label (MultivariateTF xs prob) prob deriving instance (Read (MultivariateTF (Concat xs) prob)) => Read (Multivariate dp xs prob) deriving instance (Show (MultivariateTF (Concat xs) prob)) => Show (Multivariate dp xs prob) deriving instance (Eq (MultivariateTF (Concat xs) prob)) => Eq (Multivariate dp xs prob) deriving instance (Ord (MultivariateTF (Concat xs) prob)) => Ord (Multivariate dp xs prob) deriving instance (Monoid (MultivariateTF (Concat xs) prob)) => Monoid (Multivariate dp xs prob) deriving instance (Group (MultivariateTF (Concat xs) prob)) => Group (Multivariate dp xs prob) deriving instance (NFData (MultivariateTF (Concat xs) prob)) => NFData (Multivariate dp xs prob) instance ( HomTrainer (MultivariateTF (Concat xs) prob) , Trainable dp , GetHList dp ~ Datapoint (MultivariateTF (Concat xs) prob) ) => HomTrainer (Multivariate dp xs prob) where type Datapoint (Multivariate dp xs prob) = dp train1dp dp = Multivariate $ train1dp $ getHList dp instance Probabilistic (Multivariate dp xs prob) where type Probability (Multivariate dp xs prob) = prob instance ( PDF (MultivariateTF (Concat xs) prob) , Probability (MultivariateTF (Concat xs) prob) ~ prob , Datapoint (MultivariateTF (Concat xs) prob) ~ GetHList dp , Trainable dp , HomTrainer (Multivariate dp xs prob) ) => PDF (Multivariate dp xs prob) where pdf (Multivariate dist) dp = pdf dist (getHList dp) instance ( Marginalize' (Nat1Box n) (MultivariateTF (Concat xs) prob) , MarginalizeOut' (Nat1Box n) (MultivariateTF (Concat xs) prob) ~ MultivariateTF (Concat (Replace2D n xs (Ignore' (Index (HList2TypeList (GetHList dp)) n)))) prob ) => Marginalize' (Nat1Box n) (Multivariate dp xs prob) where type Margin' (Nat1Box n) (Multivariate dp xs prob) = Margin' (Nat1Box n) (MultivariateTF (Concat xs) prob) getMargin' n (Multivariate dist) = getMargin' n dist type MarginalizeOut' (Nat1Box n) (Multivariate dp xs prob) = Multivariate dp (Replace2D n xs (Ignore' (Index (HList2TypeList (GetHList dp)) n))) prob marginalizeOut' n (Multivariate dist) = Multivariate $ marginalizeOut' n dist condition' n (Multivariate dist) dp = Multivariate $ condition' n dist dp type family HList2TypeList hlist :: [a] type instance HList2TypeList (HList xs) = xs type family Index (xs::[a]) (i::Nat1) :: a type instance Index (x ': xs) Zero = x type instance Index (x ': xs) (Succ i) = Index xs i type family Replace2D (n :: Nat1) (xs :: [ [ a ] ]) (newval :: a) :: [ [ a ] ] type instance Replace2D Zero ((x ': xs) ': ys) newval = (newval ': xs) ': ys type instance Replace2D (Succ n) ((x ': xs) ': ys) newval = AppendFront x (Replace2D n (xs ': ys) newval) type instance Replace2D n ('[] ': ys) newval = '[] ': (Replace2D n ys newval) type family AppendFront (x :: a) (xs :: [[a]]) :: [[a]] type instance AppendFront x (xs ': ys) = (x ': xs) ': ys data Boxer xs = Boxer ------------------------------------------------------------------------------- -- Type functions -- type Multivariate (xs::[[* -> * -> *]]) prob = MultivariateTF (Concat xs) prob type family MultiCategorical (xs :: [*]) :: [* -> * -> *] type instance MultiCategorical '[] = ('[]) type instance MultiCategorical (x ': xs) = (CatContainer x) ': (MultiCategorical xs) -- type Dependent dist (xs :: [*]) = '[ MultiContainer (dist xs) xs ] type family Dependent (dist:: * -> [*] -> *) (xs :: [*]) :: [* -> * -> *] type instance Dependent dist xs = '[ MultiContainer dist xs ] type family Independent (dist :: * -> * -> *) (sampleL :: [*]) :: [* -> * -> *] type instance Independent dist '[] = '[] -- type instance Independent (dist :: * -> *) (x ': xs) = (Container dist x) ': (Independent dist xs) type instance Independent (dist :: * -> * -> *) (x ': xs) = (Container dist x) ': (Independent dist xs)