module Data.Probability where

import Definitive

-- | Weighted-outcome monad transformer.
--
-- A computation is a 'WriterT' accumulating a weight of type @'Product' t@,
-- run over @'ListT' m@ so that it can yield several alternatives, each paired
-- with its weight.
--
-- NOTE(review): weights are presumably combined multiplicatively along a
-- computation (via the 'Product' monoid). Confirm against Definitive's
-- 'Product' instance.
--
-- All listed instances are derived from the wrapped 'WriterT' stack through
-- newtype deriving.
newtype ProbT t m a = ProbT (WriterT (Product t) (ListT m) a)
                    deriving (Unit,Functor,Applicative
                             ,Semigroup,Monoid
                             ,MonadFix,MonadWriter (Product t))
-- | 'Monad' instance for 'ProbT'. 'join' is obtained through the 'ProbT'
-- newtype with 'coerceJoin'. This presumably reuses the join of the wrapped
-- WriterT/ListT stack; confirm against Definitive's 'coerceJoin'.
-- The 'Ring' constraint on the weight type is required here but not by the
-- derived instances above.
instance (Ring t,Monad m) => Monad (ProbT t m) where join = coerceJoin ProbT
-- | Weighted computations with no further effects: 'ProbT' over the
-- identity monad 'Id'.
type Prob t a = ProbT t Id a

-- | Isomorphism between a 'ProbT' computation and the 'WriterT' stack it
-- wraps. The forward direction rebuilds the newtype with the 'ProbT'
-- constructor, and the other direction strips it off.
i'ProbT :: Iso (ProbT t m a) (ProbT t' m' a')
               (WriterT (Product t) (ListT m) a) (WriterT (Product t') (ListT m') a')
i'ProbT = iso ProbT unwrapProbT
  where unwrapProbT (ProbT inner) = inner
-- | View a 'ProbT' computation as a base-monad action that yields the list of
-- @(weight, outcome)@ pairs. The 'Product' wrapper is removed from each
-- weight, as the @m [(t,a)]@ side of the type shows.
--
-- The optic is composed from four parts: @i'ProbT@ (newtype), @writerT@
-- (writer layer), @mapping (i'pair i'_ id)@ (unwraps 'Product' in the first
-- component of every pair, leaves the outcome alone) and @listT@ (list layer).
probT :: (Functor m,Functor m') => Iso (ProbT t m a) (ProbT t' m' a') (m [(t,a)]) (m' [(t',a')])
probT = listT.mapping (i'pair i'_ id).writerT.i'ProbT
-- | 'probT' specialised to 'Prob' (base monad 'Id'). It views a pure weighted
-- computation directly as its list of @(weight, outcome)@ pairs.
prob :: Iso (Prob t a) (Prob t' a') [(t,a)] [(t',a')]
prob = i'Id.probT

-- | Turn a constraint witness on the weight type @t@ into one on
-- @'Prob' t a@. The argument's value is never inspected; only its type is
-- used, to fix @t@.
c'prob :: Constraint t -> Constraint (Prob t a)
c'prob = \_ -> c'_

-- | 'fork' turns a list of alternatives into a weighted choice. Every
-- alternative receives the same weight, @recip (size outcomes)@. This
-- presumably means @1/n@ for @n@ alternatives, assuming 'size' counts the
-- list's elements.
--
-- For an empty list no pair is produced, so the weight (a reciprocal of
-- zero) is never evaluated.
instance (Monad m,Invertible t) => MonadList (ProbT t m) where
  fork outcomes = pure weighted^.probT
    where weighted = [(w,o) | o <- outcomes]
          w = recip (size outcomes)

-- | @sample x p@ returns a pair @(hit, total)@:
--
-- * @total@ combines the weights of every outcome of @p@.
-- * @hit@ combines only the weights of outcomes equal to @x@; every other
--   outcome contributes 'zero'.
--
-- Weights are combined with the 'Monoid' structure of @t@ (component-wise on
-- the pair).
sample :: (Eq a,Monoid t) => a -> Prob t a -> (t,t)
sample x p = foldMap score (p^..prob)
  where
    -- Each outcome always counts towards the total, and towards the first
    -- component only when it matches x.
    score (w,y) = (if x==y then w else zero, w)