{-# OPTIONS_GHC -XFlexibleContexts #-}

module Bio.GFF3.FeatureHier ( FeatureHier, features, lookupId, lookupIdChildren
                            , fromList, insert, delete
                            , parents, children, parentsM, childrenM
                            , checkInvariants
                            )
    where 

import Control.Monad.Error
import Control.Monad.Reader
import qualified Data.ByteString.Lazy.Char8 as LBS
import Data.Graph
import Data.List hiding (insert, delete)
import qualified Data.List as List (delete)
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Set as S

import Bio.GFF3.Feature

-- | A collection of GFF3 features indexed by their IDs and by the
-- Parent relationships between them.
data FeatureHier = FeatureHier
    { features     :: !(S.Set Feature)                   -- ^ every feature in the hierarchy
    , idToFeature  :: !(M.Map LBS.ByteString Feature)    -- ^ each feature ID to the feature carrying it
    , idToChildren :: !(M.Map LBS.ByteString [Feature])  -- ^ each parent ID to the features naming it as a parent
    } deriving Show

-- | The hierarchy containing no features: all three indexes are empty.
empty :: FeatureHier
empty = FeatureHier S.empty M.empty M.empty

-- | Find the feature carrying the given ID.
--
-- Fails in the error monad (via 'strMsg') when no feature in the hierarchy
-- has that ID.
lookupId :: (Error e, MonadError e m) => FeatureHier -> LBS.ByteString -> m Feature
lookupId fh idstr = case M.lookup idstr (idToFeature fh) of
    Just feat -> return feat
    Nothing   -> throwError . strMsg $ "lookupId: feature ID " ++ show (LBS.unpack idstr) ++ " not found"

-- | Return the features recorded as children of the given ID.
--
-- Fails in the error monad (via 'strMsg') when no children are recorded
-- for the ID.
--
-- NOTE(review): 'insert' creates an entry in the child index only when a
-- feature naming the ID as its parent is added. The ID of a feature that
-- has no children therefore also produces the "not found" error here.
lookupIdChildren :: (Error e, MonadError e m) => FeatureHier -> LBS.ByteString -> m [Feature]
lookupIdChildren fh idstr = case M.lookup idstr (idToChildren fh) of
    Just kids -> return kids
    Nothing   -> throwError . strMsg $ "lookupIdChildren: feature ID " ++ show (LBS.unpack idstr) ++ " not found"

-- | Build a hierarchy from features given in any order.
--
-- The features are first reordered parents-first, so a child may precede
-- its parent in the input. Each is then added with 'insert', and the first
-- 'insert' failure (duplicate ID, unknown parent ID) is propagated.
fromList :: (Error e, MonadError e m) => [Feature] -> m FeatureHier
fromList = foldM (\hier feat -> insert feat hier) empty . parentsFirst

-- | Order features so that parents come before their children.
--
-- Edges run parent --> child (see 'featureEdges'). 'topSort' places vertex
-- i before vertex j whenever j is reachable from i but not vice versa. For
-- an acyclic parent relation, every parent therefore precedes all of its
-- descendants. Within a cycle the relative order is unspecified.
parentsFirst :: [Feature] -> [Feature]
parentsFirst feats = map (nodeOf . fromVertex) (topSort graph)
    where -- A plain (lazy) pattern binding: the previous bang pattern needed
          -- the BangPatterns extension, which this file does not enable, and
          -- forcing the triple early had no useful effect.
          (graph, fromVertex, _) = graphFromEdges (featureEdges feats)
          nodeOf (feat, _, _) = feat

-- | Convert features into the @(node, key, out-keys)@ triples expected by
-- 'graphFromEdges'.
--
-- Each feature's key is its 1-based position in the input list. Its
-- out-edges point at the keys of every feature that lists one of this
-- feature's IDs among its parent IDs, so edges run parent --> child.
featureEdges :: [Feature] -> [(Feature, Int, [Int])]
featureEdges feats = [ (feat, key, concatMap childKeysOf (ids feat)) | (feat, key) <- numbered ]
    where numbered = zip feats [1 ..]
          -- parent ID -> keys of the features naming that ID as a parent
          childKeyIndex = foldl' recordParents M.empty numbered
          recordParents acc (feat, key) = foldl' (\m pid -> M.insertWith' (++) pid [key] m) acc (parentIds feat)
          childKeysOf fid = M.findWithDefault [] fid childKeyIndex


-- | Add a feature to the hierarchy.
--
-- Fails in the error monad if any of the feature's IDs is already in the ID
-- index, or if any of its parent IDs is not in the ID index of the input
-- hierarchy. A feature therefore cannot name itself as its own parent. ID
-- clashes are checked (and reported) before parent IDs.
insert :: (Error e, MonadError e m) => Feature -> FeatureHier -> m FeatureHier
insert f hier0 = do
    newIdToFeature <- foldM addId (idToFeature hier0) (ids f)
    newIdToChildren <- foldM addChild (idToChildren hier0) (parentIds f)
    return FeatureHier { features     = S.insert f (features hier0)
                       , idToFeature  = newIdToFeature
                       , idToChildren = newIdToChildren
                       }
  where
    addId m fid
        | M.member fid m = throwError $ strMsg $ "insertFeature: Duplicate ID " ++ show fid
        | otherwise      = return $ M.insert fid f m
    -- Parents are looked up in the original hierarchy, not the updated one.
    addChild m pid
        | M.member pid (idToFeature hier0) = return $ M.insertWith' (++) pid [f] m
        | otherwise = throwError $ strMsg $ "insertFeature: Parent ID " ++ show pid ++ " not present"

-- | Remove a feature from the hierarchy.
--
-- Fails in the error monad in any of these cases:
--
--   * the feature is not in the hierarchy;
--   * one of its IDs is missing from the ID index;
--   * any of its IDs still has children (children must be deleted first);
--   * it is not recorded as a child of one of its parent IDs.
delete :: (Error e, MonadError e m) => Feature -> FeatureHier -> m FeatureHier
delete f hier0
    = liftM3 FeatureHier doNewAllFeatures doNewIdToFeature doNewIdToChildren
    where doNewAllFeatures = if S.member f $ features hier0
                             then return $ S.delete f $ features hier0
                             else throwError $ strMsg $ "deleteFeature: Feature not present " ++ show f
          doNewIdToFeature = foldM deleteIdToFeature (idToFeature hier0) $ ids f
          deleteIdToFeature m0 fid = if M.member fid m0
                                     then return $ M.delete fid m0
                                     else throwError $ strMsg $ "deleteFeature: ID not present for feature " ++ show (fid, f)
          -- A parent may only be removed once it has no remaining children.
          -- Test for a non-empty child list rather than mere key presence, so
          -- an emptied entry can never block the deletion.
          doNewIdToChildren = if any hasChildren $ ids f
                              then throwError $ strMsg $ "deleteFeature: Feature has children: " ++ show f
                              else foldM deleteIdToChildren (idToChildren hier0) $ parentIds f
          hasChildren fid = maybe False (not . null) $ M.lookup fid $ idToChildren hier0
          -- The feature must be listed under each of its parent IDs. Previously
          -- this condition was inverted and rejected exactly the valid case.
          deleteIdToChildren m0 pid = if maybe False (elem f) $ M.lookup pid m0
                                      then return $ M.update dropChild pid m0
                                      else throwError $ strMsg $ "deleteFeature: Child not present for parent ID " ++ show (f, pid)
          -- Remove one occurrence of the feature. Drop the entry entirely when
          -- no children remain, matching 'insert', which never creates empty
          -- entries.
          dropChild chs = case List.delete f chs of
                            []   -> Nothing
                            chs' -> Just chs'

-- | The features named by the given feature's parent IDs, in parent-ID order.
--
-- Partial: an element whose parent ID is not in the ID index is an 'error'
-- thunk, raised only when that element is forced.
parents :: FeatureHier -> Feature -> [Feature]
parents fh f = [ lookupParent pid | pid <- parentIds f ]
    where lookupParent pid = fromMaybe (unknownParent pid) $ M.lookup pid $ idToFeature fh
          unknownParent pid = error $ "featParents: Unknown parent id " ++ show pid ++ " for " ++ show f

-- | All features recorded as children of any of the given feature's IDs.
-- IDs with no recorded children contribute nothing.
children :: FeatureHier -> Feature -> [Feature]
children fh f = concat [ M.findWithDefault [] fid childIndex | fid <- ids f ]
    where childIndex = idToChildren fh

-- | 'parents', reading the hierarchy from a reader environment.
parentsM :: (MonadReader FeatureHier m) => Feature -> m [Feature]
parentsM f = asks (\fh -> parents fh f)

-- | 'children', reading the hierarchy from a reader environment.
childrenM :: (MonadReader FeatureHier m) => Feature -> m [Feature]
childrenM f = asks (\fh -> children fh f)

-- | Audit the internal consistency of a 'FeatureHier'.
--
-- Returns human-readable problem descriptions. An empty list means no
-- inconsistency was found. The checks are:
--
--   * every ID of every feature maps back to that same feature;
--   * every parent ID of every feature names a known feature, and the
--     feature is recorded among that parent ID's children;
--   * every feature in the ID index is in the feature set;
--   * every feature recorded as a child is in the feature set.
checkInvariants :: FeatureHier -> [String]
checkInvariants hier
    = concat [ checkAllFeatureIDs, checkAllFeatureParents, checkFeaturesInAll, checkChildrenInAll ]
    where allFeats = S.toList $ features hier
          inAll f = S.member f $ features hier
          checkAllFeatureIDs = concatMap checkFeatureIDs allFeats
          checkFeatureIDs f = concatMap checkFeatureID $ ids f
              where checkFeatureID fid = case M.lookup fid $ idToFeature hier of
                                           Nothing -> ["Feature ID " ++ show fid ++ " missing for " ++ show f]
                                           Just f' | f' /= f -> ["Wrong feature " ++ show f' ++ " /= " ++ show f ++ " for ID " ++ show fid]
                                                   | otherwise -> []
          -- Previously a stub: each parent ID must resolve, and the child
          -- must be listed under it.
          checkAllFeatureParents = concatMap checkFeatureParents allFeats
          checkFeatureParents f = concatMap checkFeatureParent $ parentIds f
              where checkFeatureParent pid
                        =  [ "Parent ID " ++ show pid ++ " missing for " ++ show f
                           | not $ M.member pid $ idToFeature hier ]
                        ++ [ "Feature " ++ show f ++ " not listed as a child of parent ID " ++ show pid
                           | not $ maybe False (elem f) $ M.lookup pid $ idToChildren hier ]
          -- Previously a stub: the ID index must not refer to unknown features.
          checkFeaturesInAll = [ "Feature " ++ show f ++ " for ID " ++ show fid ++ " not in feature set"
                               | (fid, f) <- M.toList $ idToFeature hier
                               , not $ inAll f ]
          -- Previously a stub: the child index must not refer to unknown features.
          checkChildrenInAll = [ "Child " ++ show ch ++ " of parent ID " ++ show pid ++ " not in feature set"
                               | (pid, chs) <- M.toList $ idToChildren hier
                               , ch <- chs
                               , not $ inAll ch ]