{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
module ELynx.Tree.Phylogeny
  ( 
    equal,
    intersect,
    bifurcating,
    outgroup,
    midpoint,
    roots,
    rootAt,
    
    Phylo (..),
    toPhyloTree,
    measurableToPhyloTree,
    supportedToPhyloTree,
    Length (..),
    phyloToLengthTree,
    Support (..),
    phyloToSupportTree,
    phyloToSupportTreeUnsafe,
    PhyloExplicit (..),
    toExplicitTree,
  )
where
import Control.DeepSeq
import Data.Aeson
import Data.Bifoldable
import Data.Bifunctor
import Data.Bitraversable
import Data.List hiding (intersect)
import Data.Maybe
import Data.Monoid
import Data.Semigroup
import Data.Set (Set)
import qualified Data.Set as S
import ELynx.Tree.Bipartition
import ELynx.Tree.Measurable
import ELynx.Tree.Rooted
import ELynx.Tree.Splittable
import ELynx.Tree.Supported
import GHC.Generics
-- | Test whether two trees are equal when the order of children is ignored at
-- every node.
--
-- Branches and labels are compared with '=='. Sub-forests must match as
-- multisets: every child on the left is paired with a distinct child on the
-- right that is 'equal' to it (recursively, so nested child order is ignored
-- too).
equal :: (Eq e, Eq a) => Tree e a -> Tree e a -> Bool
equal (Node brL lbL tsL) (Node brR lbR tsR) =
  (brL == brR)
    && (lbL == lbR)
    && (length tsL == length tsR)
    && matchForests tsL tsR
  where
    -- Greedily pair each left child with an 'equal' right child and remove
    -- that partner, so duplicated subtrees must be matched one-to-one.
    -- Greedy pairing is sound because 'equal' is an equivalence relation.
    matchForests [] ys = null ys
    matchForests (x : xs) ys = case break (equal x) ys of
      (_, []) -> False
      (before, _ : after) -> matchForests xs (before ++ after)
-- | Restrict a forest to the leaves shared by all of its trees.
--
-- Every leaf that is not present in all trees is removed from each tree with
-- 'dropLeavesWith'. An empty forest yields @Right []@ (previously this
-- crashed in 'foldl1'').
--
-- Returns 'Left' if the trees share no leaves, or if 'dropLeavesWith' returns
-- 'Nothing' for some tree.
intersect ::
  (Semigroup e, Eq e, Ord a) => Forest e a -> Either String (Forest e a)
intersect [] = Right []
intersect ts
  | S.null lvsCommon = Left "intersect: Intersection of leaves is empty."
  | otherwise =
      maybe (Left "intersect: A tree is empty.") Right $
        sequence (zipWith dropUncommon leavesToDrop ts)
  where
    -- Leaf set of each tree.
    lvss = map (S.fromList . leaves) ts
    -- Leaves present in every tree. 'foldl1'' is safe here because the
    -- empty forest is handled by the equation above.
    lvsCommon = foldl1' S.intersection lvss
    -- For each tree, the leaves it has that are not shared by all trees.
    leavesToDrop = map (S.\\ lvsCommon) lvss
    dropUncommon lvsToDr = dropLeavesWith (`S.member` lvsToDr)
-- | Check whether every inner node of the tree has exactly two children.
-- Leaves count as bifurcating.
bifurcating :: Tree e a -> Bool
bifurcating (Node _ _ children) = case children of
  [] -> True
  [left, right] -> bifurcating left && bifurcating right
  _ -> False
-- | Root a tree using an outgroup.
--
-- @outgroup o r t@ roots @t@ on the branch separating the leaves in @o@ from
-- all other leaves of @t@. The new root node gets the label @r@.
--
-- Returns 'Left' if:
--
-- * the root of @t@ has fewer than three children;
-- * the tree has duplicate leaves;
-- * 'bp' or 'rootAt' fails.
outgroup :: (Semigroup e, Splittable e, Ord a) => Set a -> a -> Tree e a -> Either String (Tree e a)
outgroup _ _ (Node _ _ []) = Left "outgroup: Root node is a leaf."
outgroup _ _ (Node _ _ [_]) = Left "outgroup: Root node has degree two."
outgroup _ _ (Node _ _ [_, _]) = Left "outgroup: Root node is bifurcating."
outgroup o r t@(Node b l (Node brO lbO tsO : tsRest))
  | duplicateLeaves t = Left "outgroup: Tree has duplicate leaves."
  | otherwise = do
    bip <- bp o (S.fromList (leaves t) S.\\ o)
    rootAt bip t'
  where
    -- Temporarily make the root bifurcating by detaching the first child and
    -- grouping the remaining children under a new node that carries the old
    -- root label. Both new root branches get 'split' of the first child's
    -- branch. When the root is moved, 'descend' merges the two root branches
    -- again with '<>'.
    halfO = split brO
    t' = Node b r [Node halfO lbO tsO, Node halfO l tsRest]
-- | Root a tree at its midpoint.
--
-- The root must be bifurcating. All rootings produced by 'roots' are passed
-- to 'getMidpoint', which selects and adjusts one of them.
midpoint :: (Semigroup e, Splittable e, Measurable e) => Tree e a -> Either String (Tree e a)
midpoint t@(Node _ _ children) = case children of
  [] -> Left "midpoint: Root node is a leaf."
  [_] -> Left "midpoint: Root node has degree two."
  [_, _] -> fmap getMidpoint (roots t)
  _ -> Left "midpoint: Root node is multifurcating."
-- | Index of the smallest element of a non-empty list. On ties, the last
-- occurrence wins. Calls 'error' on an empty list.
findMinIndex :: Ord a => [a] -> Int
findMinIndex [] = error "findMinIndex: Empty list."
findMinIndex (x : xs) = fst (foldl' keepSmaller (0, x) (zip [1 ..] xs))
  where
    -- Keep the current candidate only if it is strictly smaller.
    keepSmaller best@(_, z) (j, y)
      | z < y = best
      | otherwise = (j, y)
-- | Select the rooting whose two root subtrees differ least in height. Then
-- adjust the stems of both root children with 'applyStem' so that the two
-- heights become equal.
--
-- Expects a non-empty list of trees with bifurcating roots (as produced by
-- 'roots'); calls 'error' otherwise.
getMidpoint :: Measurable e => [Tree e a] -> Tree e a
getMidpoint candidates =
  case candidates !! findMinIndex (map getDeltaHeight candidates) of
    Node br lb [l, r] ->
      -- Shift half of the height difference from one side to the other.
      let shift = (height l - height r) / 2
       in Node br lb [applyStem (subtract shift) l, applyStem (+ shift) r]
    _ -> error "getMidpoint: Root node is not bifurcating."
-- | Absolute difference between the heights of the two subtrees of a
-- bifurcating root. Calls 'error' for any other root.
getDeltaHeight :: Measurable e => Tree e a -> Double
getDeltaHeight t = case t of
  Node _ _ [left, right] -> abs (height left - height right)
  _ -> error "getDeltaHeight: Root node is not bifurcating."
-- | All rootings of a tree with a bifurcating root.
--
-- The tree itself comes first. It is followed by the rootings produced by
-- 'descend' into the left subtree, then into the right subtree.
roots :: (Semigroup e, Splittable e) => Tree e a -> Either String (Forest e a)
roots t@(Node rootBr rootLb children) = case children of
  [] -> Left "roots: Root node is a leaf."
  [_] -> Left "roots: Root node has degree two."
  [left, right] ->
    Right (t : (descend rootBr rootLb right left ++ descend rootBr rootLb left right))
  _ -> Left "roots: Root node is multifurcating."
-- | For each tree in @ts@, build the forest obtained by removing that tree
-- from @ts@ and putting @t@ in front.
complementaryForests :: Tree e a -> Forest e a -> [Forest e a]
complementaryForests t ts =
  [t : before ++ after | (before, _ : after) <- zip (inits ts) (tails ts)]
-- | Enumerate the rootings obtained by moving the root into the subtree @tD@
-- (the fourth argument): one result per child of @tD@, followed by the
-- results of recursing into each of those children.
--
-- Arguments:
--
-- * @brR@ and @lbR@ are the stem branch and the label given to every new root;
-- * @tC@ is the complementary tree that is currently the sibling of @tD@;
-- * @tD@ is the subtree to descend into.
--
-- Returns @[]@ when @tD@ is a leaf.
descend :: (Semigroup e, Splittable e) => e -> a -> Tree e a -> Tree e a -> Forest e a
descend _ _ _ (Node _ _ []) = []
descend brR lbR tC (Node brD lbD tsD) =
  -- One new root per child of tD. The child's branch is divided with 'split'
  -- and shared between the child itself and a node that carries tD's old
  -- label and holds the complementary forest.
  [ Node brR lbR [Node (split brDd) lbD f, Node (split brDd) lbDd tsDd]
    | (Node brDd lbDd tsDd, f) <- zip tsD cfs
  ]
    -- Recurse into each child; the freshly built complementary node becomes
    -- its sibling.
    ++ concat
      [ descend brR lbR (Node (split brDd) lbD f) (Node (split brDd) lbDd tsDd)
        | (Node brDd lbDd tsDd, f) <- zip tsD cfs
      ]
  where
    -- The root between tC and tD disappears, so tC is attached directly to
    -- tD's node and the two branches are combined with '<>'.
    brC' = branch tC <> brD
    tC' = tC {branch = brC'}
    -- For each child of tD: tC' followed by the other children of tD.
    cfs = complementaryForests tC' tsD
-- | Root a tree at the branch defined by a bipartition.
--
-- Returns 'Left' if:
--
-- * the tree has duplicate leaves;
-- * 'toSet' of the bipartition differs from the leaf set of the tree;
-- * 'roots' fails (the root must be bifurcating);
-- * no rooting of the tree induces the bipartition.
rootAt ::
  (Semigroup e, Splittable e, Eq a, Ord a) =>
  Bipartition a ->
  Tree e a ->
  Either String (Tree e a)
rootAt b t
  -- Duplicates are detected by comparing the number of leaves with the size
  -- of the leaf set.
  | length lvLst /= S.size lvSet = Left "rootAt: Tree has duplicate leaves."
  | toSet b /= lvSet = Left "rootAt: Bipartition does not match leaves of tree."
  | otherwise = rootAt' b t
  where
    -- Collect the leaves once and reuse them for both checks.
    lvLst = leaves t
    lvSet = S.fromList lvLst
-- | Root a tree at a bipartition without validating the leaves first.
--
-- Returns the first rooting produced by 'roots' whose 'bipartition' equals
-- the requested one.
rootAt' ::
  (Semigroup e, Splittable e, Ord a) =>
  Bipartition a ->
  Tree e a ->
  Either String (Tree e a)
rootAt' b t = roots t >>= pick
  where
    pick candidates =
      maybe (Left "rootAt': Bipartition not found on tree.") Right (find hasBip candidates)
    hasBip x = Right b == bipartition x
-- | Branch label of a phylogenetic tree in which the branch length and the
-- branch support are both optional.
data Phylo = Phylo
  { -- | Branch length, if available.
    brLen :: Maybe BranchLength,
    -- | Branch support, if available.
    brSup :: Maybe BranchSupport
  }
  deriving stock (Read, Show, Eq, Ord, Generic)
  deriving anyclass (NFData, ToJSON, FromJSON)

-- | Branch lengths are added and branch supports combine to their minimum.
-- A value missing on one side is ignored; the result is missing only when
-- both sides are missing.
instance Semigroup Phylo where
  Phylo lenL supL <> Phylo lenR supR =
    Phylo (mergeWith (+) lenL lenR) (mergeWith min supL supR)
    where
      mergeWith f (Just x) (Just y) = Just (f x y)
      mergeWith _ Nothing my = my
      mergeWith _ mx Nothing = mx
-- | Convert a tree whose branches have both a length and a support value
-- into a tree of 'Phylo' branches with both fields set.
toPhyloTree :: (Measurable e, Supported e) => Tree e a -> Tree Phylo a
toPhyloTree t = first toPhyloLabel t

-- | Branch conversion used by 'toPhyloTree'.
toPhyloLabel :: (Measurable e, Supported e) => e -> Phylo
toPhyloLabel x = Phylo {brLen = Just (getLen x), brSup = Just (getSup x)}

-- | Convert a tree with measurable branches; supports are left empty.
measurableToPhyloTree :: Measurable e => Tree e a -> Tree Phylo a
measurableToPhyloTree t = first measurableToPhyloLabel t

-- | Branch conversion used by 'measurableToPhyloTree'.
measurableToPhyloLabel :: Measurable e => e -> Phylo
measurableToPhyloLabel x = Phylo {brLen = Just (getLen x), brSup = Nothing}

-- | Convert a tree with supported branches; lengths are left empty.
supportedToPhyloTree :: Supported e => Tree e a -> Tree Phylo a
supportedToPhyloTree t = first supportedToPhyloLabel t

-- | Branch conversion used by 'supportedToPhyloTree'.
supportedToPhyloLabel :: Supported e => e -> Phylo
supportedToPhyloLabel x = Phylo {brLen = Nothing, brSup = Just (getSup x)}
-- | Branch label carrying only a (mandatory) branch length.
--
-- Arithmetic is that of 'Double'. Combining two lengths with '<>' adds them,
-- and 'mempty' is zero.
newtype Length = Length {fromLength :: BranchLength}
  deriving stock (Read, Show, Eq, Ord, Generic)
  deriving anyclass (NFData, ToJSON, FromJSON)
  deriving (Num, Fractional, Floating) via Double
  deriving (Semigroup, Monoid) via Sum Double

instance Measurable Length where
  getLen (Length x) = x
  setLen x _ = Length x

-- | Splitting a branch halves its length.
instance Splittable Length where
  split (Length x) = Length (x / 2.0)
-- | Convert a 'Phylo' tree into a tree of branch lengths.
--
-- A missing length on the root branch is replaced by @0@. A missing length
-- on any other branch makes the conversion fail with 'Left'.
phyloToLengthTree :: Tree Phylo a -> Either String (Tree Length a)
phyloToLengthTree t =
  case bitraverse toLength pure (cleanRootLength t) of
    Nothing -> Left "phyloToLengthTree: Length unavailable for some branches."
    Just lengthTree -> Right lengthTree

-- | Set the root branch length to @0@ if it is missing.
cleanRootLength :: Tree Phylo a -> Tree Phylo a
cleanRootLength t@(Node p l f) = case brLen p of
  Nothing -> Node p {brLen = Just 0} l f
  Just _ -> t

-- | Extract the length of a branch, if present.
toLength :: Phylo -> Maybe Length
toLength = fmap Length . brLen
-- | Branch label carrying only a (mandatory) branch support value.
--
-- Arithmetic is that of 'Double'. Combining two supports with '<>' keeps the
-- smaller one.
newtype Support = Support {fromSupport :: BranchSupport}
  deriving stock (Read, Show, Eq, Ord, Generic)
  deriving anyclass (NFData, ToJSON, FromJSON)
  deriving (Num, Fractional, Floating) via Double
  deriving (Semigroup) via Min Double

instance Supported Support where
  getSup (Support s) = s
  setSup s _ = Support s

-- | Splitting a branch leaves its support unchanged.
instance Splittable Support where
  split s = s
-- | Convert a 'Phylo' tree into a tree of branch supports.
--
-- Missing supports on the root branch and on leaf branches are filled with
-- 'getMaxSupport' of the tree. A missing support on any other branch makes
-- the conversion fail with 'Left'.
phyloToSupportTree :: Tree Phylo a -> Either String (Tree Support a)
phyloToSupportTree t =
  maybe (Left "phyloToSupportTree: Support unavailable for some branches.") Right
    . bitraverse toSupport pure
    . cleanLeafSupport maxSup
    . cleanRootSupport maxSup
    $ t
  where
    maxSup = getMaxSupport t

-- | Like 'phyloToSupportTree', but a missing support on any branch is filled
-- with 'getMaxSupport' of the tree, so the conversion cannot fail.
phyloToSupportTreeUnsafe :: Tree Phylo a -> Tree Support a
phyloToSupportTreeUnsafe t = cleanSupport (getMaxSupport t) t

-- | Largest branch support present in the tree, but never less than @1.0@.
-- Returns @1.0@ when no branch has a support value.
getMaxSupport :: Tree Phylo a -> BranchSupport
getMaxSupport t =
  -- 'max' with @Just 1.0@ never yields 'Nothing', so the default of
  -- 'fromMaybe' is never used.
  fromMaybe 1.0 . max (Just 1.0) . bimaximum $ bimap brSup (const Nothing) t
-- | Fill a missing support on the root branch with the given value.
cleanRootSupport :: BranchSupport -> Tree Phylo a -> Tree Phylo a
cleanRootSupport maxSup t@(Node p l xs) = case brSup p of
  Nothing -> Node p {brSup = Just maxSup} l xs
  Just _ -> t

-- | Fill missing supports on leaf branches with the given value. Inner
-- branches are left unchanged.
cleanLeafSupport :: BranchSupport -> Tree Phylo a -> Tree Phylo a
cleanLeafSupport s (Node p l xs)
  | Nothing <- brSup p, null xs = Node p {brSup = Just s} l []
  | otherwise = Node p l (map (cleanLeafSupport s) xs)

-- | Extract the support of a branch, if present.
toSupport :: Phylo -> Maybe Support
toSupport = fmap Support . brSup

-- | Convert every branch to a 'Support', filling missing values with the
-- given one.
cleanSupport :: BranchSupport -> Tree Phylo a -> Tree Support a
cleanSupport maxSup = go
  where
    go (Node p l xs) = Node (Support (fromMaybe maxSup (brSup p))) l (map go xs)
-- | Branch label with a mandatory branch length and a mandatory branch
-- support value.
data PhyloExplicit = PhyloExplicit
  { sBrLen :: BranchLength,
    sBrSup :: BranchSupport
  }
  -- 'NFData' is derived as it is for 'Phylo', 'Length', and 'Support', so
  -- trees with explicit branch labels can also be fully evaluated.
  deriving (Read, Show, Eq, Ord, Generic, NFData)

-- | Lengths are added and supports combine to their minimum.
instance Semigroup PhyloExplicit where
  PhyloExplicit bL sL <> PhyloExplicit bR sR = PhyloExplicit (bL + bR) (min sL sR)

instance Measurable PhyloExplicit where
  getLen = sBrLen
  setLen b l = l {sBrLen = b}

-- | Splitting a branch halves its length and keeps its support.
instance Splittable PhyloExplicit where
  split l = l {sBrLen = b'}
    where
      b' = sBrLen l / 2.0

instance Supported PhyloExplicit where
  getSup = sBrSup
  setSup s l = l {sBrSup = s}

instance ToJSON PhyloExplicit

instance FromJSON PhyloExplicit
-- | Convert a 'Phylo' tree into a tree with explicit lengths and supports.
--
-- Lengths come from 'phyloToLengthTree' and supports from
-- 'phyloToSupportTree'. Fails with 'Left' if either conversion fails.
toExplicitTree :: Tree Phylo a -> Either String (Tree PhyloExplicit a)
toExplicitTree t = do
  lengthTree <- fmap (first fromLength) (phyloToLengthTree t)
  supportTree <- fmap (first fromSupport) (phyloToSupportTree t)
  -- Both trees are derived from the same input, so zipping is expected to
  -- succeed; failure indicates a bug.
  maybe
    (error "toExplicitTree: This is a bug. Can not zip two trees with the same topology.")
    Right
    (zipTreesWith PhyloExplicit const lengthTree supportTree)