module HLearn.Models.Distributions.Multivariate.Internal.Container
( Container
, MultiContainer
)
where
import Control.DeepSeq
import Debug.Trace
import GHC.TypeLits
import HLearn.Algebra
import HLearn.Models.Distributions.Common
import HLearn.Models.Distributions.Multivariate.Internal.Ignore
import HLearn.Models.Distributions.Multivariate.Internal.Marginalization
-- | A product of two distributions.
--
-- * 'dist' models the head variable of the datapoint 'HList'.
-- * 'basedist' models the remaining variables.
--
-- Training hands each component its own part of the datapoint. The 'PDF'
-- instance multiplies the two component densities, so the components are
-- treated as independent.
--
-- The derived 'Ord' compares 'dist' first, then 'basedist'.
data Container (dist :: * -> a -> *) (sample:: a) basedist (prob:: * ) = Container
    { dist :: dist prob sample -- ^ distribution over the head variable
    , basedist :: basedist     -- ^ distribution over the remaining variables
    }
    deriving (Read,Show,Eq,Ord)
-- | Reduce a 'Container' to normal form: the head distribution first, then
-- the base distribution.
instance (NFData (dist prob sample), NFData basedist) => NFData (Container dist sample basedist prob) where
    rnf (Container headDist restDist) = rnf headDist `seq` rnf restDist
-- | A 'Container' whose first component models a whole 'HList' of variables
-- rather than a single head variable. The 'HomTrainer' and 'PDF' instances
-- for 'MultiContainer' split the datapoint with 'htake1' / 'hdrop1'.
--
-- Monoid, Abelian, Group and NFData are inherited from the wrapped
-- 'Container' through newtype deriving.
newtype MultiContainer dist sample basedist prob = MultiContainer (Container dist sample basedist prob)
    deriving (Read,Show,Eq,Ord,Monoid,Abelian,Group,NFData)
-- | A 'Container' is 'Abelian' whenever both of its components are.
instance (Abelian (dist prob sample), Abelian basedist) => Abelian (Container dist sample basedist prob)
-- | Componentwise monoid.
--
-- The identity holds two empty components. Combining two containers
-- combines the head distributions with each other and the base
-- distributions with each other.
instance
    ( Monoid (dist prob sample)
    , Monoid basedist
    ) => Monoid (Container dist sample basedist prob)
  where
    mempty = Container { dist = mempty, basedist = mempty }

    -- The irrefutable patterns keep this exactly as lazy in its arguments
    -- as a definition written with the field selectors.
    mappend ~(Container headL restL) ~(Container headR restR) =
        Container (headL <> headR) (restL <> restR)
-- | Componentwise group structure: inverting a container inverts both of
-- its components.
instance
    ( Group (dist prob sample)
    , Group basedist
    ) => Group (Container dist sample basedist prob)
  where
    -- Irrefutable pattern: the argument is only forced when a field is demanded.
    inverse ~(Container headDist restDist) =
        Container (inverse headDist) (inverse restDist)
-- | The scalar ring of a 'Container' is the ring of its head distribution.
--
-- The equality constraint requires 'basedist' to use that same ring. The
-- 'Module' instance depends on this, because it scales both components by
-- one scalar.
instance
    ( HasRing (dist prob sample)
    , HasRing basedist
    , Ring (dist prob sample) ~ Ring basedist
    ) => HasRing (Container dist sample basedist prob)
    where
        type Ring (Container dist sample basedist prob) = Ring (dist prob sample)
-- | The scalar ring of a 'MultiContainer' is the ring of its head
-- distribution, as for 'Container'.
--
-- This instance is written by hand instead of being listed in the newtype's
-- deriving clause. Presumably that is because the associated type 'Ring'
-- could not be newtype-derived (to be confirmed).
instance
    ( HasRing (dist prob sample)
    , HasRing basedist
    , Ring (dist prob sample) ~ Ring basedist
    ) => HasRing (MultiContainer dist sample basedist prob)
    where
        type Ring (MultiContainer dist sample basedist prob) = Ring (dist prob sample)
-- | Scalar multiplication distributes over both components. The shared ring
-- is guaranteed by the equality constraint.
instance
    ( Module (dist prob sample)
    , Module basedist
    , Ring (dist prob sample) ~ Ring basedist
    ) => Module (Container dist sample basedist prob)
  where
    -- Irrefutable pattern: the container is not forced until a field is needed.
    (.*) scalar ~(Container headDist restDist) =
        Container (scalar .* headDist) (scalar .* restDist)
-- | 'Module' for 'MultiContainer' is obtained by standalone deriving, which
-- reuses the 'Container' instance. The context repeats the 'Container'
-- requirements.
deriving instance
    ( Module (dist prob sample)
    , Module basedist
    , Ring (dist prob sample) ~ Ring basedist
    ) => Module (MultiContainer dist sample basedist prob)
-- | Train a 'Container' from one datapoint.
--
-- The head of the 'HList' trains the head distribution. The tail trains the
-- base distribution.
instance
    ( HomTrainer (dist prob sample)
    , HomTrainer basedist
    , Datapoint basedist ~ HList ys
    ) => HomTrainer (Container dist sample basedist prob)
  where
    type Datapoint (Container dist sample basedist prob) =
        (Datapoint (dist prob sample)) `HCons` (Datapoint basedist)

    train1dp (headDp ::: restDp) = Container (train1dp headDp) (train1dp restDp)
-- | The number of datapoints seen is whatever the head distribution reports.
instance
    ( NumDP (dist prob sample)
    , HasRing basedist
    , Ring basedist ~ Ring (dist prob sample)
    ) => NumDP (Container dist sample basedist prob)
  where
    numdp = numdp . dist
-- | Train a 'MultiContainer' from one datapoint.
--
-- The datapoint is the concatenation of the head distribution's 'HList'
-- (@zs@) and the base distribution's 'HList' (@ys@). It is split after
-- @Length1 zs@ entries. The prefix trains the head distribution and the
-- suffix trains the base distribution.
instance
    ( HomTrainer (dist prob sample)
    , HomTrainer basedist
    , Datapoint (dist prob sample) ~ HList zs
    , Datapoint basedist ~ HList ys
    , HTake1 (Nat1Box (Length1 zs)) (HList (zs++ys)) (HList zs)
    , HDrop1 (Nat1Box (Length1 zs)) (HList (zs++ys)) (HList ys)
    ) => HomTrainer (MultiContainer dist sample basedist prob)
  where
    type Datapoint (MultiContainer dist sample basedist prob) =
        (Datapoint (dist prob sample)) `HAppend` (Datapoint basedist)

    train1dp dpL = MultiContainer $ Container (train1dp front) (train1dp back)
      where
        -- Type-level split point: the number of entries owned by the head distribution.
        prefixLen = Nat1Box :: Nat1Box (Length1 zs)
        front = htake1 prefixLen dpL
        back = hdrop1 prefixLen dpL
-- | The datapoint count comes from the head distribution of the wrapped
-- 'Container'.
instance
    ( NumDP (dist prob sample)
    , HasRing basedist
    , Ring basedist ~ Ring (dist prob sample)
    ) => NumDP (MultiContainer dist sample basedist prob)
  where
    numdp (MultiContainer inner) = numdp (dist inner)
-- | Probabilities of a 'Container' are values of its @prob@ type parameter.
instance Probabilistic (Container dist sample basedist prob) where
    type Probability (Container dist sample basedist prob) = prob
-- | The density of a datapoint is the product of two densities:
--
-- * the head distribution evaluated at the head of the 'HList';
-- * the base distribution evaluated at the tail.
--
-- This product form treats the two components as independent.
instance
    ( PDF (dist prob sample)
    , PDF basedist
    , Probability (dist prob sample) ~ prob
    , Probability basedist ~ prob
    , Probabilistic (Container dist sample basedist prob)
    , Datapoint basedist ~ HList ys
    , Datapoint (dist prob sample) ~ y
    , Datapoint (Container dist sample basedist prob) ~ HList (y ': ys)
    , Num prob
    ) => PDF (Container dist sample basedist prob)
  where
    pdf container (dp:::basedp) =
        pdf (dist container) dp * pdf (basedist container) basedp
-- | Index zero selects the head variable.
--
-- * Its margin is the head distribution itself.
-- * Marginalizing it out keeps only the base distribution, wrapped in
--   'Ignore''.
-- * Conditioning on it also keeps only the base distribution, wrapped in
--   'Ignore''. The observed value is not used, which is consistent with the
--   product-form 'PDF' treating the components as independent.
instance Marginalize' (Nat1Box Zero) (Container dist (sample :: *) basedist prob) where
    type Margin' (Nat1Box Zero) (Container dist sample basedist prob) = dist prob sample
    getMargin' _ = dist

    type MarginalizeOut' (Nat1Box Zero) (Container dist sample basedist prob) = Ignore' sample basedist prob
    marginalizeOut' _ = Ignore' . basedist

    condition' _ container _ = Ignore' (basedist container)
-- | Index @Succ n@ refers to variable @n@ of the base distribution.
--
-- Every operation recurses into 'basedist' with the predecessor index and
-- leaves the head distribution untouched.
--
-- The index is passed as a total @Nat1Box@ proxy built with the constructor.
-- The previous @undefined :: Nat1Box n@ would crash if any instance ever
-- inspected the proxy. Building it with the constructor also matches how
-- proxies are created elsewhere in this module.
instance
    ( Marginalize' (Nat1Box n) basedist
    ) => Marginalize' (Nat1Box (Succ n)) (Container dist sample basedist prob)
  where
    type Margin' (Nat1Box (Succ n)) (Container dist sample basedist prob) = Margin' (Nat1Box n) basedist
    getMargin' _ container = getMargin' (Nat1Box :: Nat1Box n) $ basedist container

    type MarginalizeOut' (Nat1Box (Succ n)) (Container dist sample basedist prob) =
        Container dist sample (MarginalizeOut' (Nat1Box n) basedist) prob
    marginalizeOut' _ container = Container
        { dist = dist container
        , basedist = marginalizeOut' (Nat1Box :: Nat1Box n) $ basedist container
        }

    condition' _ container dp = Container
        { dist = dist container
        , basedist = condition' (Nat1Box :: Nat1Box n) (basedist container) dp
        }
-- | Probabilities of a 'MultiContainer' are values of its @prob@ type parameter.
instance Probabilistic (MultiContainer dist sample basedist prob) where
    type Probability (MultiContainer dist sample basedist prob) = prob
-- | The density of a concatenated datapoint is computed in three steps:
--
-- 1. Split the 'HList' after @Length1 dpL@ entries.
-- 2. Evaluate the head distribution on the prefix and the base distribution
--    on the suffix.
-- 3. Multiply the two densities.
instance
    ( PDF (dist prob sample)
    , PDF basedist
    , prob ~ Probability (dist prob sample)
    , prob ~ Probability basedist
    , Num prob
    , Datapoint (dist prob sample) ~ HList dpL
    , Datapoint basedist ~ HList basedpL
    , HTake1 (Nat1Box (Length1 dpL)) (HList (dpL ++ basedpL)) (HList dpL)
    , HDrop1 (Nat1Box (Length1 dpL)) (HList (dpL ++ basedpL)) (HList basedpL)
    ) => PDF (MultiContainer dist sample basedist prob)
  where
    pdf (MultiContainer container) dp = headDensity * restDensity
      where
        -- Type-level split point: the number of entries owned by the head distribution.
        prefixLen = Nat1Box :: Nat1Box (Length1 dpL)
        headDensity = pdf (dist container) (htake1 prefixLen dp)
        restDensity = pdf (basedist container) (hdrop1 prefixLen dp)