-- |
-- Module      :  ELynx.Tree.Simulate.Coalescent
-- Description :  Generate coalescent trees
-- Copyright   :  (c) Dominik Schrempf 2021
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Wed May 16 13:13:11 2018.
module ELynx.Tree.Simulate.Coalescent
  ( simulate,
  )
where

import Control.Monad.Primitive
import ELynx.Tree.Distribution.CoalescentContinuous
import ELynx.Tree.Length
import ELynx.Tree.Rooted
import Statistics.Distribution
import System.Random.MWC

-- | Simulate a coalescent tree with @n@ leaves. The branch lengths are in units
-- of effective population size.
--
-- The leaves are labeled @0@ to @n - 1@ and start with stem length zero. Every
-- internal node created while joining lineages is labeled @0@.
--
-- Calls 'error' if @n <= 0@ (a tree needs at least one leaf).
simulate ::
  (PrimMonad m) =>
  -- | Number of leaves.
  Int ->
  Gen (PrimState m) ->
  m (Tree Length Int)
simulate n = simulate' n 0 trs
  where
    -- Initial forest: one childless node per leaf.
    trs :: Forest Length Int
    trs = [Node 0 i [] | i <- [0 .. n - 1]]

-- Recursive worker for 'simulate'.
--
-- Arguments, in order: the number @n@ of trees still to be joined, the label
-- @a@ given to every newly created internal node, the current forest @trs@
-- (expected to contain exactly @n@ trees), and the random number generator.
--
-- In each step, one pair of neighboring trees is joined and the recursion
-- continues with @n - 1@ trees until a single tree is left.
--
-- Calls 'error' if @n <= 0@. A forest whose size does not match @n@ also ends
-- in a call to 'error'.
simulate' ::
  (PrimMonad m) =>
  Int ->
  Int ->
  Forest Length Int ->
  Gen (PrimState m) ->
  m (Tree Length Int)
simulate' n a trs g
  | n <= 0 = error "Cannot construct trees without leaves."
  | n == 1 = case trs of
      -- Exactly one tree left: the simulation is finished.
      [tr] -> return tr
      _ -> error "Too many trees provided."
  | otherwise = do
    -- Indices of the trees to join will be i-1 and i; 'uniformR' includes both
    -- bounds for 'Int', so i-1 ranges over [0, n-2].
    i <- uniformR (1, n - 1) g
    -- The time of the coalescent event.
    t <- toLengthUnsafe <$> genContVar (coalescentDistributionCont n) g
    -- Move time 't' up on the tree: extend the stem of every tree by 't'.
    let trs' :: Forest Length Int
        trs' = map (modifyStem (+ t)) trs
    -- Split into the trees on the left, the two trees to join, and the trees
    -- on the right. Pattern matching avoids partial list indexing.
    case splitAt (i - 1) trs' of
      (ls, tl : tr : rs) ->
        let -- Join the two chosen trees below a new node with stem length 0.
            tm :: Tree Length Int
            tm = Node 0 a [tl, tr]
         in simulate' (n - 1) a (ls ++ [tm] ++ rs) g
      _ -> error "simulate': forest has fewer trees than expected."