{-# LANGUAGE DataKinds
           , TypeOperators
           , NoImplicitPrelude
           , FlexibleContexts
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.04.21
-- |
-- Module      :  Language.Hakaru.Inference
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- TODO: we may want to give these longer\/more-explicit names so
-- as to be a bit less ambiguous in the larger Haskell ecosystem.
----------------------------------------------------------------
module Language.Hakaru.Inference
    ( priorAsProposal
    , mh
    , mcmc
    , gibbsProposal
    , slice
    , sliceX
    , incompleteBeta
    , regBeta
    , tCDF
    , approxMh
    , kl
    ) where

import Prelude (($), (.), error, Maybe(..), return)
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.Syntax.AST (Term)
import Language.Hakaru.Syntax.ABT (ABT, binder)
import Language.Hakaru.Syntax.Prelude
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Expect (expect, normalize)
import Language.Hakaru.Disintegrate (determine, density, disintegrate)

import qualified Data.Text as Text

----------------------------------------------------------------
----------------------------------------------------------------
-- | Turn a prior over pairs into a proposal kernel.
--
-- A fair coin is flipped and a fresh pair is drawn from the prior
-- @p@. The result is the current state @x@ with exactly one of its
-- components replaced by the matching component of the fresh draw:
-- the second component when the coin is true, the first otherwise.
priorAsProposal
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] ('HMeasure (HPair a b))
    -> abt '[] (HPair a b)
    -> abt '[] ('HMeasure (HPair a b))
priorAsProposal p x =
    bern (prob_ 0.5) >>= \coin ->
    p >>= \fresh ->
    let keepFst = pair (fst x) (snd fresh)  -- old first, new second
        keepSnd = pair (fst fresh) (snd x)  -- new first, old second
    in  dirac (if_ coin keepFst keepSnd)


-- The accept-or-reject step of MCMC is deliberately not performed
-- here, because @min@ and @bern@ don't do well in @simplify@. The
-- AST produced by 'mh' is meant to be passed through 'simplify'
-- before being plugged into @mcmc@; that's why 'easierRoadmapProg4'
-- and 'easierRoadmapProg4'' have different types.
--
-- TODO: the @a@ type should be pure (aka @a ~ Expect' a@ in the old parlance).
--
-- The lambda is built with 'lamWithVar' and an explicit singleton
-- recovered from @typeOf target@, so no 'SingI' constraint is needed
-- in the signature.

-- | Given a proposal kernel and a target measure, build a kernel
-- that returns the proposed point paired with the ratio of the
-- target density at the proposed point over the target density at
-- the old point. Calls 'error' when the density of the target cannot
-- be determined.
mh  :: (ABT Term abt)
    => abt '[] (a ':-> 'HMeasure a)
    -> abt '[] ('HMeasure a)
    -> abt '[] (a ':-> 'HMeasure (HPair a 'HProb))
mh proposal target =
    case determine (density target) of
    Just theDensity ->
        let_ theDensity $ \mu ->
        lamWithVar Text.empty ty $ \old ->
            app proposal old >>= \new ->
            -- Only the target density enters the ratio: mu(new) / mu(old).
            dirac (pair_ ty SProb new (app mu new / app mu old))
    Nothing -> error "mh: couldn't get density"
  where
    -- Singleton for the sample type, shared by the lambda and the pair.
    ty = sUnMeasure (typeOf target)

-- | One Metropolis-Hastings transition.
--
-- Runs the kernel built by 'mh' on the old state to obtain a
-- proposed state together with its ratio, then accepts the proposal
-- with probability @min 1 ratio@; on rejection the old state is
-- returned unchanged.
--
-- The lambda is built with 'lamWithVar' and the singleton recovered
-- from @typeOf target@, so the signature carries no 'SingI'
-- constraint.
mcmc :: (ABT Term abt)
    => abt '[] (a ':-> 'HMeasure a)
    -> abt '[] ('HMeasure a)
    -> abt '[] (a ':-> 'HMeasure a)
mcmc proposal target =
    let_ (mh proposal target) $ \step ->
    lamWithVar Text.empty ty $ \old ->
        app step old >>= \proposed ->
        unpair proposed $ \new ratio ->
        bern (min (prob_ 1) ratio) >>= \accept ->
        dirac (if_ accept new old)
  where
    ty = sUnMeasure (typeOf target)


-- | Gibbs-style proposal for a measure over pairs.
--
-- The joint measure @p@ is disintegrated; the first component of
-- the current state @xy@ is kept, and the second component is
-- redrawn from the normalized measure obtained by applying the
-- disintegration to that first component. Calls 'error' when the
-- disintegration cannot be determined.
gibbsProposal
    :: (ABT Term abt, SingI a, SingI b)
    => abt '[] ('HMeasure (HPair a b))
    -> abt '[] (HPair a b)
    -> abt '[] ('HMeasure (HPair a b))
gibbsProposal p xy =
    case determine (disintegrate p) of
    Just conditional ->
        -- The old second component is discarded.
        unpair xy $ \x _ ->
        pair x <$> normalize (app conditional x)
    Nothing -> error "gibbsProposal: couldn't disintegrate"


-- Slice sampling can be thought of as:
--
-- slice target x = do
--      u  <- uniform(0, density(target, x))
--      x' <- lebesgue
--      condition (density(target, x') >= u) true
--      return x'

-- | Slice-sampling kernel for a measure on the reals.
--
-- Given the current point, a height @u@ is drawn uniformly between
-- zero and the target density at that point; the next point is then
-- drawn from the normalized restriction of 'lebesgue' to the points
-- whose density lies strictly above @u@. Calls 'error' when the
-- density of the target cannot be determined.
slice
    :: (ABT Term abt)
    => abt '[] ('HMeasure 'HReal)
    -> abt '[] ('HReal ':-> 'HMeasure 'HReal)
slice target =
    case determine (density target) of
    Just densAt ->
        -- Density of the target at a point, as a real number.
        let height pt = fromProb (app densAt pt) in
        lam $ \x ->
            uniform (real_ 0) (height x) >>= \u ->
            normalize
                ( lebesgue >>= \x' ->
                  withGuard (u < height x') (dirac x') )
    Nothing -> error "slice: couldn't get density"


-- | Augment a target measure with a slice height.
--
-- Each draw @x@ from the target is combined, via 'bindx', with a
-- draw from the uniform distribution between zero and the target
-- density at @x@. Calls 'error' when the density of the target
-- cannot be determined.
sliceX
    :: (ABT Term abt, SingI a)
    => abt '[] ('HMeasure a)
    -> abt '[] ('HMeasure (HPair a 'HReal))
sliceX target =
    case determine (density target) of
    Just densAt ->
        bindx target $ \x ->
        uniform (real_ 0) (fromProb (app densAt x))
    Nothing -> error "sliceX: couldn't get density"


-- | Incomplete beta function, written as an integral.
--
-- @incompleteBeta x a b@ integrates @t ** (a - 1) * (1 - t) ** (b - 1)@
-- for @t@ ranging from zero to @x@.
incompleteBeta
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
incompleteBeta x a b =
    integrate (real_ 0) (fromProb x) integrand
  where
    -- Exponents of the two factors, as reals.
    expA = fromProb a - real_ 1
    expB = fromProb b - real_ 1
    integrand t =
        (unsafeProb t ** expA) * (unsafeProb (real_ 1 - t) ** expB)


-- | Regularized incomplete beta function: 'incompleteBeta' divided
-- by the beta function of the same two shape arguments.
--
-- TODO: rename 'regularBeta'
regBeta
    :: (ABT Term abt)
    => abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
    -> abt '[] 'HProb
regBeta x a b =
    let numerator   = incompleteBeta x a b
        denominator = betaFunc a b
    in  numerator / denominator


-- | Cumulative distribution function of Student's t-distribution,
-- evaluated at @x@ with @v@ degrees of freedom, expressed through
-- the regularized incomplete beta function 'regBeta'.
--
-- Writing @b@ for 'regBeta' applied to @v \/ (x*x + v)@, @v \/ 2@
-- and @0.5@, the CDF is @1 - b \/ 2@ for non-negative @x@ and, by
-- symmetry of the distribution, @b \/ 2@ for negative @x@. Since
-- @b@ depends on @x@ only through @x*x@, the sign has to be
-- inspected explicitly; using the first formula for every @x@ would
-- give the same value at @x@ and @-x@.
tCDF :: (ABT Term abt) => abt '[] 'HReal -> abt '[] 'HProb -> abt '[] 'HProb
tCDF x v =
    let b = regBeta (v / (unsafeProb (x*x) + v)) (v / prob_ 2) (prob_ 0.5)
    in  if_ (x < real_ 0)
            -- Lower tail: half of the regularized beta value.
            (prob_ 0.5 * b)
            -- Upper half: complement of the lower-tail mass at @-x@.
            (unsafeProb $ real_ 1 - real_ 0.5 * fromProb b)


-- BUG: get rid of the SingI requirements due to using 'lam'
--
-- | Sketch of an approximate Metropolis-Hastings kernel.
--
-- Arguments, in order: a proposal (as a Haskell function on terms),
-- a prior measure, and a list of functions of the same shape as the
-- proposal. Only the length of that list is used by the code below:
-- the head of the list is ignored and the tail is passed to the
-- recursive call.
--
-- Calls 'error' when the list is empty, and when the density of
-- @bindx prior proposal@ cannot be determined.
--
-- NOTE(review): this looks like work in progress. The likelihood
-- helper @l@ is a stub that ignores its arguments, and the recursion
-- shrinks the list by one each time, so the innermost fallback
-- branch is always the empty-list 'error'.
approxMh
    :: (ABT Term abt, SingI a)
    => (abt '[] a -> abt '[] ('HMeasure a))
    -> abt '[] ('HMeasure a)
    -> [abt '[] a -> abt '[] ('HMeasure a)]
    -> abt '[] (a ':-> 'HMeasure a)
approxMh _ _ [] = error "TODO: approxMh for empty list"
approxMh proposal prior (_:xs) =
    -- Density of the joint measure over pairs built by 'bindx' from
    -- the prior and the proposal.
    case determine . density $ bindx prior proposal of
    Nothing -> error "approxMh: couldn't get density"
    Just theDensity ->
        lam $ \old ->
        let_ theDensity $ \mu ->
        -- A uniform draw on the unit interval, coerced to 'HProb'.
        unsafeProb <$> uniform (real_ 0) (real_ 1) >>= \u ->
        proposal old >>= \new ->
        -- The uniform draw scaled by the ratio of the joint density
        -- at (new, old) over the joint density at (old, new).
        let_ (u * mu `app` pair new old / mu `app` pair old new) $ \u0 ->
        -- Ratio of the stub @l@ at the new and old points.
        let_ (l new new / l old old) $ \l0 ->
        -- 'tCDF' receives @n - 1@ as its first argument and the
        -- difference @l0 - u0@ as its second.
        let_ (tCDF (n - real_ 1) (udif l0 u0)) $ \delta ->
        -- Below the threshold @eps@: decide now, keeping @new@ when
        -- @u0 < l0@ and @old@ otherwise. Otherwise fall back to the
        -- kernel built from the remaining list, applied to @old@.
        if_ (delta < eps)
            (if_ (u0 < l0)
                (dirac new)
                (dirac old))
            (approxMh proposal prior xs `app` old)
    where
    -- Hard-coded constants used in the 'tCDF' call and the threshold test.
    n   = real_ 2000
    eps = prob_ 0.05
    -- Difference of two 'HProb' terms, coerced back with 'unsafeProb'.
    -- NOTE(review): nothing here guarantees @lo >= hi@ — confirm.
    udif lo hi = unsafeProb $ fromProb lo - fromProb hi
    -- Stub: ignores both arguments and always yields @prob_ 2@.
    l = \_d1 _d2 -> prob_ 2 -- determine (density (\theta -> x theta))

-- | Kullback-Leibler-style divergence between two measures.
--
-- Yields 'Nothing' when the density of either measure cannot be
-- determined. Otherwise the result is the expectation, under @p@,
-- of the logarithm of the ratio of the density of @p@ over the
-- density of @q@.
--
-- NOTE(review): the log-ratio is coerced with @unsafeProb@ before
-- being integrated; presumably it is negative wherever the density
-- of @q@ exceeds that of @p@ — confirm that this coercion is
-- intended.
kl :: (ABT Term abt)
   => abt '[] ('HMeasure a)
   -> abt '[] ('HMeasure a)
   -> Maybe (abt '[] 'HProb)
kl p q = do
  dp <- determine (density p)
  dq <- determine (density q)
  let logRatio i = unsafeProb (log (app dp i / app dq i))
  return (expect p (binder Text.empty (sUnMeasure (typeOf p)) logRatio))