{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}

-- |
-- Module      :  ELynx.Data.Tree.Phylogeny
-- Description :  Phylogenetic trees
-- Copyright   :  (c) Dominik Schrempf 2020
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  portable
--
-- Creation date: Thu Jan 17 16:08:54 2019.
--
-- A phylogeny is a 'Tree' with unique leaf labels, and the order of the trees
-- in the sub-forest is considered to be meaningless.
--
-- Internally, however, the underlying 'Tree' data structure stores the
-- sub-forest as a list, which has a specific order. Hence, we have to do some
-- tricks when comparing trees, and tree comparison is slow.
--
-- Also, the uniqueness of the leaves is not ensured by the data type, but has
-- to be checked at runtime. Functions relying on the tree to have unique leaves
-- perform this check, and return 'Left' with an error message if the tree has
-- duplicate leaves.
--
-- Note: 'Tree's are rooted.
--
-- Note: 'Tree's encoded in Newick format correspond to rooted trees. By
-- convention only, a tree parsed from Newick format is usually thought to be
-- unrooted, when the root node is multifurcating and has three children. This
-- convention is not enforced here. Newick trees are just parsed as they are,
-- and a rooted tree is returned.
--
-- The bifurcating root of a tree can be changed with 'roots', or 'rootAt'.
--
-- Trees with multifurcating root nodes can be properly rooted using 'outgroup'.
module ELynx.Data.Tree.Phylogeny
  ( -- * Functions
    equal,
    intersect,
    bifurcating,
    outgroup,
    midpoint,
    roots,
    rootAt,

    -- * Branch labels
    Phylo (..),
    Length (..),
    phyloToLengthTree,
    lengthToPhyloTree,
    Support (..),
    phyloToSupportTree,
    phyloToSupportTreeUnsafe,
    PhyloStrict (..),
    toStrictTree,
    fromStrictTree
  )
where

import Control.DeepSeq
import Data.Aeson
import Data.Bifoldable
import Data.Bifunctor
import Data.Bitraversable
import Data.List hiding (intersect)
import Data.Maybe
import Data.Monoid
import Data.Semigroup
import Data.Set (Set)
import qualified Data.Set as S
import ELynx.Data.Tree.Bipartition
import ELynx.Data.Tree.Measurable
import ELynx.Data.Tree.Rooted
import ELynx.Data.Tree.Splittable
import ELynx.Data.Tree.Supported
import GHC.Generics

-- | Check two trees for equality, ignoring the order of the sub-forests.
--
-- Two trees are equal if their root branches and root labels are equal, and if
-- the children of the one root can be matched one-to-one with equal children of
-- the other root. Children are compared recursively with 'equal', so the order
-- of the sub-forests is ignored at every level of the tree, not only at the
-- root.
--
-- The check is slow, because a child may have to be compared with every child
-- of the other tree.
equal :: (Eq e, Eq a) => Tree e a -> Tree e a -> Bool
equal (Node brL lbL tsL) (Node brR lbR tsR) =
  brL == brR && lbL == lbR && matchForests tsL tsR
  where
    -- Pair each left child with one right child that has not been used yet.
    -- Greedy matching is sufficient because 'equal' is an equivalence relation.
    matchForests [] ys = null ys
    matchForests (x : xs) ys = case break (equal x) ys of
      (_, []) -> False
      (before, _ : after) -> matchForests xs (before ++ after)

-- | Compute the intersection of trees.
--
-- The intersections are the largest subtrees sharing the same leaf set. From
-- each tree, the leaves not shared by all trees are removed with
-- 'dropLeavesWith'.
--
-- Return 'Left' if:
--
-- - the forest is empty;
--
-- - the intersection of leaves is empty;
--
-- - 'dropLeavesWith' returns 'Nothing' for any of the trees.
intersect ::
  (Semigroup e, Eq e, Ord a) => Forest e a -> Either String (Forest e a)
-- An empty forest has no leaf sets to intersect ('foldl1'' would fail).
intersect [] = Left "intersect: No trees provided."
intersect ts@(_ : _)
  | S.null lvsCommon = Left "intersect: Intersection of leaves is empty."
  | otherwise =
    maybe (Left "intersect: A tree is empty.") Right $
      traverse dropSurplus (zip lvss ts)
  where
    -- Leaf sets.
    lvss = map (S.fromList . leaves) ts
    -- Common leaf set.
    lvsCommon = foldl1' S.intersection lvss
    -- Drop the leaves of a tree that are not part of the common leaf set.
    dropSurplus (lvs, t) = dropLeavesWith (`S.member` (lvs S.\\ lvsCommon)) t

-- | Check if a tree is bifurcating.
--
-- A node is accepted if it has no children (a leaf), or if it has exactly two
-- children which are both bifurcating themselves. Nodes with one child, or with
-- more than two children, render the tree non-bifurcating.
bifurcating :: Tree e a -> Bool
bifurcating (Node _ _ children) = case children of
  [] -> True
  [l, r] -> bifurcating l && bifurcating r
  _ -> False

-- TODO. Adapt for trees with branches.
-- -- | Remove multifurcations.
-- --
-- -- A caterpillar like bifurcating tree is used to resolve all multifurcations on
-- -- a tree. The multifurcating nodes are copied.
-- --
-- -- Branch labels are not handled.
-- resolve :: Tree e a -> Tree e a
-- resolve t@(Node _ _ []) = t
-- resolve (Node _ l [x]) = Node () l [resolve x]
-- resolve (Node _ l [x, y]) = Node () l $ map resolve [x, y]
-- resolve (Node _ l (x : xs)) = Node () l $ map resolve [x, Node () l xs]

-- | Resolve a multifurcation at the root using an outgroup.
--
-- A bifurcating root node with the provided label is introduced: the first
-- child of the former root becomes one daughter, and the former root (keeping
-- its label and the remaining children) becomes the other daughter. The branch
-- of the first child is 'split' between the two daughters. The resulting tree
-- is then rooted with 'rootAt' at the bipartition separating the outgroup from
-- all other leaves.
--
-- Note, the degree of the former root node is decreased by one.
--
-- If the root node is bifurcating, use 'rootAt'.
--
-- Return 'Left' if
--
-- - the tree has duplicate leaves;
--
-- - the root node is not multifurcating;
--
-- - the bipartition cannot be created ('bp'), or 'rootAt' fails; according to
--   the original documentation this is the case if the outgroup is not found
--   on the tree or is polyphyletic.
outgroup :: (Semigroup e, Splittable e, Ord a) => Set a -> a -> Tree e a -> Either String (Tree e a)
outgroup _ _ (Node _ _ []) = Left "outgroup: Root node is a leaf."
outgroup _ _ (Node _ _ [_]) = Left "outgroup: Root node has degree two."
outgroup _ _ (Node _ _ [_, _]) = Left "outgroup: Root node is bifurcating."
-- Match the first child directly instead of using the partial 'head'/'tail'.
outgroup o r t@(Node b l (Node brO lbO tsO : tsRest))
  | duplicateLeaves t = Left "outgroup: Tree has duplicate leaves."
  | otherwise = do
    bip <- bp o (S.fromList (leaves t) S.\\ o)
    rootAt bip t'
  where
    -- Introduce a bifurcating root node.
    t' = Node b r [Node (split brO) lbO tsO, Node (split brO) l tsRest]

-- XXX: The 'midpoint' algorithm is pretty stupid because it calculates all
-- rooted trees and then finds the one minimizing the difference between the
-- heights of the left and right sub tree. Actually, one just needs to move left
-- or right, with the aim to minimize the height difference between the left and
-- right sub tree.

-- | Root tree at the midpoint.
--
-- All rootings of the tree are computed with 'roots', and the best one is
-- selected and balanced by 'getMidpoint'.
--
-- Return 'Left' if
--
-- - the root node is not bifurcating.
midpoint :: (Semigroup e, Splittable e, Measurable e) => Tree e a -> Either String (Tree e a)
midpoint t@(Node _ _ children) = case children of
  [] -> Left "midpoint: Root node is a leaf."
  [_] -> Left "midpoint: Root node has degree two."
  [_, _] -> fmap getMidpoint (roots t)
  _ -> Left "midpoint: Root node is multifurcating."

-- Find the (zero-based) index of the minimum of a list.
--
-- If the minimum occurs more than once, the index of its last occurrence is
-- returned, because the current candidate is only kept when it is strictly
-- smaller than the next element.
--
-- Call 'error' if the list is empty.
findMinIndex :: Ord a => [a] -> Int
findMinIndex [] = error "findMinIndex: Empty list."
findMinIndex (x : xs) = fst $ foldl' step (0, x) (zip [1 ..] xs)
  where
    -- Keep the candidate (index, value), or replace it with the next element.
    step best@(_, z) (j, y)
      | z < y = best
      | otherwise = (j, y)

-- From a list of trees with bifurcating root nodes, select the tree with the
-- smallest height difference between the left and the right sub tree (see
-- 'getDeltaHeight' and 'findMinIndex'), and shift half of the height
-- difference from the stem of the left sub tree to the stem of the right sub
-- tree using 'applyStem'.
--
-- Call 'error' if the list is empty, or if the selected tree does not have a
-- bifurcating root node.
getMidpoint :: Measurable e => [Tree e a] -> Tree e a
getMidpoint ts = case ts !! findMinIndex (map getDeltaHeight ts) of
  Node br lb [l, r] ->
    let dh = (height l - height r) / 2
     in Node br lb [applyStem (subtract dh) l, applyStem (+ dh) r]
  -- Explicitly use 'error' here, because roots is supposed to return trees with
  -- bifurcating root nodes.
  _ -> error "getMidpoint: Root node is not bifurcating."

-- find index of minimum; take this tree and move root to the midpoint of the branch

-- Absolute difference between the heights of the two sub trees below a
-- bifurcating root node.
--
-- Calls 'error' for any other root degree; the trees handled here are expected
-- to stem from 'roots', which only returns trees with bifurcating root nodes.
getDeltaHeight :: Measurable e => Tree e a -> Double
getDeltaHeight (Node _ _ children) = case children of
  [l, r] -> abs (height l - height r)
  _ -> error "getDeltaHeight: Root node is not bifurcating."

-- | For a rooted tree with a bifurcating root node, get all possible rooted
-- trees.
--
-- The root node is moved. The first element of the result is the given tree
-- itself; it is followed by the rootings found when descending into the left
-- sub tree, and then by those found when descending into the right sub tree.
--
-- For a tree with @l=2@ leaves, there is one rooted tree. For a bifurcating
-- tree with @l>2@ leaves, there are @(2l-3)@ rooted trees. For a general tree
-- with a bifurcating root node, and a total number of @n>2@ nodes, there are
-- (n-2) rooted trees.
--
-- Moving a multifurcating root node to another branch would change the degree
-- of the root node. Hence, a bifurcating root is required. To resolve a
-- multifurcating root, please use 'outgroup'.
--
-- Branches are connected according to the 'Semigroup' instance, and the branch
-- receiving the root is divided according to the 'Splittable' instance.
--
-- Return 'Left' if the root node is not 'bifurcating'.
roots :: (Semigroup e, Splittable e) => Tree e a -> Either String (Forest e a)
roots t@(Node b c children) = case children of
  [] -> Left "roots: Root node is a leaf."
  [_] -> Left "roots: Root node has degree two."
  [tL, tR] -> Right (t : (descend b c tR tL ++ descend b c tL tR))
  _ -> Left "roots: Root node is multifurcating."

-- For each tree of the forest, build the forest consisting of the given tree
-- @t@ followed by all other trees of the forest (the tree itself is left out).
-- The result has one entry per tree of the forest, in the same order.
complementaryForests :: Tree e a -> Forest e a -> [Forest e a]
complementaryForests t ts =
  -- The last pair of the zip has an empty tail; it fails the pattern match and
  -- is skipped.
  [t : before ++ after | (before, _ : after) <- zip (inits ts) (tails ts)]

-- From the bifurcating root, descend into one of the two sub trees.
--
-- Arguments: root branch, root label, complementary tree (the part of the tree
-- that is not descended into), and downwards tree (the part descended into).
--
-- The branch of the downwards tree is appended to the branch of the
-- complementary tree ('<>'). Then, for each child of the downwards tree, a new
-- root is placed on the branch leading to that child: the branch is 'split',
-- one half leads to the child, the other half leads to a node carrying the
-- label of the downwards tree and holding the complementary tree together with
-- the siblings of the child.
--
-- The result lists the rootings on the branches of the direct children first,
-- followed by the rootings found by descending further into each child.
descend :: (Semigroup e, Splittable e) => e -> a -> Tree e a -> Tree e a -> Forest e a
descend _ _ _ (Node _ _ []) = []
descend brR lbR tC (Node brD lbD tsD) =
  [Node brR lbR [tUp, tDown] | (tUp, tDown) <- halves]
    ++ concatMap (uncurry (descend brR lbR)) halves
  where
    -- Complementary tree with the branch of the downwards tree attached.
    tC' = tC {branch = branch tC <> brD}
    -- For each child: the new upwards (complementary) tree and the child, both
    -- attached to one half of the split branch.
    halves =
      [ (Node (split brDd) lbD f, Node (split brDd) lbDd tsDd)
        | (Node brDd lbDd tsDd, f) <- zip tsD (complementaryForests tC' tsD)
      ]

-- | Root a tree at a specific position.
--
-- Root the tree at the branch defined by the given bipartition. The original
-- root node is moved to the new position.
--
-- The root node must be bifurcating (see 'roots' and 'outgroup').
--
-- Connect branches according to the provided 'Semigroup' instance.
--
-- Upon insertion of the root, split the affected branch according to the
-- provided 'Splittable' instance.
--
-- Return 'Left', if:
--
-- - the root node is not bifurcating;
--
-- - the tree has duplicate leaves;
--
-- - the bipartition does not match the leaves of the tree.
rootAt ::
  (Semigroup e, Splittable e, Eq a, Ord a) =>
  Bipartition a ->
  Tree e a ->
  Either String (Tree e a)
rootAt b t
  -- Tree is checked for being bifurcating in 'roots'.
  --
  -- Do not use 'duplicateLeaves' here, because we also need to compare the leaf
  -- set with the bipartition.
  | length lvLst /= S.size lvSet = Left "rootAt: Tree has duplicate leaves."
  | toSet b /= lvSet = Left "rootAt: Bipartition does not match leaves of tree."
  | otherwise = rootAt' b t
  where
    lvLst = leaves t
    -- Reuse the leaf list instead of traversing the tree a second time.
    lvSet = S.fromList lvLst

-- Root the tree at the given bipartition, assuming that the leaves of the tree
-- are unique.
--
-- All rootings are computed with 'roots'; the first one whose 'bipartition'
-- equals the given bipartition is returned. Errors of 'roots' are passed on.
rootAt' ::
  (Semigroup e, Splittable e, Ord a) =>
  Bipartition a ->
  Tree e a ->
  Either String (Tree e a)
rootAt' b t =
  roots t
    >>= maybe (Left "rootAt': Bipartition not found on tree.") Right
      . find ((Right b ==) . bipartition)

-- | Branch label for phylogenetic trees.
--
-- Branches may have a length and a support value.
data Phylo = Phylo
  { -- | Branch length, if available.
    brLen :: Maybe BranchLength,
    -- | Branch support, if available.
    brSup :: Maybe BranchSupport
  }
  deriving (Read, Show, Eq, Ord, Generic, NFData)

-- | Branch lengths are added, and the minimum of the branch support values is
-- used. A missing value ('Nothing') is ignored when the other label provides a
-- value; the result is 'Nothing' only if both values are missing.
instance Semigroup Phylo where
  Phylo mBL mSL <> Phylo mBR mSR =
    Phylo
      (getSum <$> ((Sum <$> mBL) <> (Sum <$> mBR)))
      (getMin <$> ((Min <$> mSL) <> (Min <$> mSR)))

-- JSON instances use the default implementations of the classes.
instance ToJSON Phylo

instance FromJSON Phylo

-- | Branch length label.
--
-- Numeric instances are borrowed from 'Double'; lengths are combined by
-- addition ('Sum').
--
-- For conversion, see 'phyloToLengthTree' and 'lengthToPhyloTree'.
newtype Length = Length {fromLength :: BranchLength}
  deriving stock (Read, Show, Eq, Ord, Generic)
  deriving anyclass (NFData)
  deriving (Num, Fractional, Floating) via Double
  deriving (Semigroup, Monoid) via (Sum Double)

-- The length is the wrapped value; setting a length replaces the label.
instance Measurable Length where
  getLen (Length b) = b
  setLen b = const (Length b)

-- Splitting halves the length.
instance Splittable Length where
  split (Length b) = Length (b / 2.0)

-- JSON instances use the default implementations of the classes.
instance ToJSON Length

instance FromJSON Length

-- | Convert the branch labels to 'Length'.
--
-- A missing root branch length is set to 0 (see 'cleanRootLength').
--
-- Return 'Left' if any other branch length is unavailable.
phyloToLengthTree :: Tree Phylo a -> Either String (Tree Length a)
phyloToLengthTree t = case bitraverse toLength pure (cleanRootLength t) of
  Nothing -> Left "phyloToLengthTree: Length unavailable for some branches."
  Just lengthTree -> Right lengthTree

-- Set a missing length of the root branch to zero; leave the tree unchanged
-- otherwise. Only the root node is inspected.
cleanRootLength :: Tree Phylo a -> Tree Phylo a
cleanRootLength t@(Node p l f) = case brLen p of
  Nothing -> Node p {brLen = Just 0} l f
  Just _ -> t

-- Extract the branch length, if available; the support value is discarded.
toLength :: Phylo -> Maybe Length
toLength = fmap Length . brLen

-- | Convert the branch labels to 'Phylo'. The branch lengths are kept, and all
-- branch support values are set to 'Nothing'.
--
-- Useful, for example, to export a tree with branch lengths in Newick format.
lengthToPhyloTree :: Tree Length a -> Tree Phylo a
lengthToPhyloTree = first (\l -> Phylo (Just (fromLength l)) Nothing)

-- Wrap a length into a 'Phylo' label without support value.
fromLengthLabel :: Length -> Phylo
fromLengthLabel (Length len) = Phylo {brLen = Just len, brSup = Nothing}

-- | Branch support label.
--
-- Numeric instances are borrowed from 'Double'; support values are combined by
-- taking the minimum ('Min').
--
-- For conversion, see 'phyloToSupportTree'.
newtype Support = Support {fromSupport :: BranchSupport}
  deriving stock (Read, Show, Eq, Ord, Generic)
  deriving anyclass (NFData)
  deriving (Num, Fractional, Floating) via Double
  deriving (Semigroup) via (Min Double)

-- The support is the wrapped value; setting a support replaces the label.
instance Supported Support where
  getSup (Support s) = s
  setSup s = const (Support s)

-- Splitting a branch leaves the support value unchanged.
instance Splittable Support where
  split s = s

-- JSON instances use the default implementations of the classes.
instance ToJSON Support

instance FromJSON Support

-- | Convert the branch labels to 'Support'.
--
-- Missing branch support values of branches leading to the leaves and of the
-- root branch are set to maximum support (see 'getMaxSupport').
--
-- Return 'Left' if any other branch has no available support value.
phyloToSupportTree :: Tree Phylo a -> Either String (Tree Support a)
phyloToSupportTree t =
  maybe (Left "phyloToSupportTree: Support unavailable for some branches.") Right
    . bitraverse toSupport pure
    . cleanLeafSupport m
    $ cleanRootSupport m t
  where
    m = getMaxSupport t

-- | Convert the branch labels to 'Support'; all unavailable branch support
-- values are set to maximum support (see 'getMaxSupport').
phyloToSupportTreeUnsafe :: Tree Phylo a -> Tree Support a
phyloToSupportTreeUnsafe t = cleanSupport (getMaxSupport t) t

-- Get the maximum branch support value of a tree. If all branch support values
-- are below 1.0 (or unavailable), the maximum support is 1.0.
--
-- The node labels are replaced by 'Nothing' so that only the branch labels
-- contribute to the maximum. The result of 'max' with @Just 1.0@ is always a
-- 'Just'; 'fromMaybe' is used instead of the partial 'fromJust' nonetheless.
getMaxSupport :: Tree Phylo a -> BranchSupport
getMaxSupport = fromMaybe 1.0 . max (Just 1.0) . bimaximum . bimap brSup (const Nothing)

-- Set a missing support value of the root branch to the given value; leave the
-- tree unchanged otherwise. Only the root node is inspected.
cleanRootSupport :: BranchSupport -> Tree Phylo a -> Tree Phylo a
cleanRootSupport maxSup t@(Node p l xs) = case brSup p of
  Nothing -> Node p {brSup = Just maxSup} l xs
  Just _ -> t

-- Set missing support values of branches leading to leaves to the given value.
-- Branches of internal nodes, and available support values, are not changed.
cleanLeafSupport :: BranchSupport -> Tree Phylo a -> Tree Phylo a
cleanLeafSupport s = go
  where
    go (Node p l []) = Node (fill p) l []
    go (Node p l xs) = Node p l (map go xs)
    -- Only replace a missing support value.
    fill p = case brSup p of
      Nothing -> p {brSup = Just s}
      Just _ -> p

-- Extract the branch support, if available; the length is discarded.
toSupport :: Phylo -> Maybe Support
toSupport = fmap Support . brSup

-- Convert all branch labels to 'Support', using the given value wherever the
-- support is unavailable. Branch lengths are discarded.
cleanSupport :: BranchSupport -> Tree Phylo a -> Tree Support a
cleanSupport maxSup = go
  where
    go (Node p l xs) = Node (Support (fromMaybe maxSup (brSup p))) l (map go xs)

-- XXX: Probably change name. Strict is reserved for "not lazy".

-- | Strict branch label for phylogenetic trees.
--
-- In contrast to 'Phylo', both the branch length and the branch support are
-- always available.
data PhyloStrict = PhyloStrict
  { -- | Branch length.
    sBrLen :: BranchLength,
    -- | Branch support.
    sBrSup :: BranchSupport
  }
  deriving (Read, Show, Eq, Ord, Generic)

-- | Branch lengths are added, and the minimum of the branch support values is
-- used.
instance Semigroup PhyloStrict where
  PhyloStrict bL sL <> PhyloStrict bR sR = PhyloStrict (bL + bR) (min sL sR)

instance Measurable PhyloStrict where
  getLen = sBrLen
  setLen b l = l {sBrLen = b}

-- Splitting halves the branch length; the support value is kept.
instance Splittable PhyloStrict where
  split l = l {sBrLen = sBrLen l / 2.0}

instance Supported PhyloStrict where
  getSup = sBrSup
  setSup s l = l {sBrSup = s}

-- JSON instances use the default implementations of the classes.
instance ToJSON PhyloStrict

instance FromJSON PhyloStrict

-- | Conversion to a 'PhyloStrict' tree.
--
-- The branch lengths are obtained with 'phyloToLengthTree', the branch support
-- values with 'phyloToSupportTree'; an error of either conversion is returned
-- as 'Left'. The two trees are then zipped into one.
toStrictTree :: Tree Phylo a -> Either String (Tree PhyloStrict a)
toStrictTree t = do
  lengthTree <- phyloToLengthTree t
  supportTree <- phyloToSupportTree t
  -- Both trees originate from the same tree, so zipping is expected to succeed.
  maybe
    (error "toStrictTree: This is a bug. Can not zip two trees with the same topology.")
    Right
    (zipTreesWith PhyloStrict const (first fromLength lengthTree) (first fromSupport supportTree))

-- | Convert the branch labels to 'Phylo'; all branch lengths and branch
-- support values are wrapped into 'Just'.
--
-- Useful, for example, to export a tree with branch lengths in Newick format.
fromStrictTree :: Tree PhyloStrict a -> Tree Phylo a
fromStrictTree = first (\(PhyloStrict b s) -> Phylo (Just b) (Just s))

-- Wrap the length and the support of a strict label into 'Just'.
fromStrictLabel :: PhyloStrict -> Phylo
fromStrictLabel (PhyloStrict len sup) = Phylo {brLen = Just len, brSup = Just sup}