module Swish.GraphPartition
    ( PartitionedGraph(..), getArcs, getPartitions
    , GraphPartition(..), node, toArcs
    , partitionGraph, comparePartitions
    , partitionShowP
    )
where
import Swish.GraphClass (Label(..), Arc(..))
import Control.Monad.State (MonadState(..), State)
import Control.Monad.State (evalState)
import Data.List (foldl', partition)
import Data.List.NonEmpty (NonEmpty(..), (<|))
import Data.Maybe (mapMaybe)
import qualified Data.List.NonEmpty as NE
-- | A graph split into a list of 'GraphPartition' trees, as produced by
--   'partitionGraph'.
data PartitionedGraph lb
    = PartitionedGraph [GraphPartition lb]
    deriving (Show, Eq)
-- | Flatten a partitioned graph back into the arcs of all its partitions,
--   in partition order.
getArcs :: PartitionedGraph lb -> [Arc lb]
getArcs (PartitionedGraph parts) = parts >>= toArcs
-- | The list of top-level partitions of a partitioned graph.
getPartitions :: PartitionedGraph lb -> [GraphPartition lb]
getPartitions pg = case pg of
    PartitionedGraph parts -> parts
-- | A tree of statements rooted at a single node.
data GraphPartition lb
    = PartObj lb
      -- ^ A leaf node with no statements attached.
    | PartSub lb (NonEmpty (lb, GraphPartition lb))
      -- ^ A subject node with one or more (predicate, object) statements;
      --   each object is itself a partition.
-- | The root label of a partition: the leaf label or the subject label.
node :: GraphPartition lb -> lb
node gp = case gp of
    PartObj leaf      -> leaf
    PartSub subject _ -> subject
-- | Convert a partition back into arcs.
--
--   Each statement yields an arc from the subject to the object's root
--   node, followed by the arcs of the nested object partition. Leaves
--   yield no arcs.
toArcs :: GraphPartition lb -> [Arc lb]
toArcs (PartObj _) = []
toArcs (PartSub subject stmts) =
    concat [ Arc subject pr (node obj) : toArcs obj | (pr, obj) <- NE.toList stmts ]
-- | Structural equality. Partitions are equal when they use the same
--   constructor, have equal labels, and have equal statement lists
--   (in order).
instance (Label lb) => Eq (GraphPartition lb) where
    PartObj o1    == PartObj o2    = o1 == o2
    PartSub s1 p1 == PartSub s2 p2 = (s1, p1) == (s2, p2)
    _             == _             = False
-- | Total order on partitions.
--
--   Every 'PartSub' sorts before every 'PartObj'. Two 'PartSub' values
--   compare by subject label, then by statement list. Two 'PartObj'
--   values compare by label.
instance (Label lb) => Ord (GraphPartition lb) where
    compare a b = case (a, b) of
        (PartSub s1 p1, PartSub s2 p2) -> compare (s1, p1) (s2, p2)
        (PartObj o1,    PartObj o2)    -> compare o1 o2
        (PartSub _ _,   PartObj _)     -> LT
        (PartObj _,     PartSub _ _)   -> GT
-- | Render a partition with 'partitionShow'. The precedence argument is
--   ignored, so the Show machinery never adds parentheses around a
--   partition.
instance (Label lb) => Show (GraphPartition lb) where
    showsPrec _ gp rest = partitionShow gp ++ rest
-- | Single-line rendering of a partition.
--
--   A leaf renders as its label. A subject renders as
--   @(subject pred obj ; pred obj ...)@, with nested objects rendered
--   recursively.
partitionShow :: (Label lb) => GraphPartition lb -> String
partitionShow gp = case gp of
    PartObj ob -> show ob
    PartSub sb (firstProp :| otherProps) ->
        concat
            [ "(", show sb, " ", renderProp firstProp
            , concatMap ((" ; " ++) . renderProp) otherProps
            , ")"
            ]
  where
    renderProp (pr, ob) = show pr ++ " " ++ show ob
-- | Rendering of a partition with a caller-supplied prefix.
--
--   @pref@ is emitted before the opening parenthesis and before each
--   @"  ; "@ separator. Nested object partitions are rendered with @pref@
--   extended by two spaces. Leaves render as their label with no prefix.
partitionShowP ::
    (Label lb) =>
    String
    -> GraphPartition lb
    -> String
partitionShowP pref gp = case gp of
    PartObj ob ->
        show ob
    PartSub sb (firstProp :| otherProps) ->
        concat
            [ pref, "(", show sb, " ", renderProp firstProp
            , concatMap (\p -> pref ++ "  ; " ++ renderProp p) otherProps
            , ")"
            ]
  where
    nestedPref = pref ++ "  "
    renderProp (pr, ob) = show pr ++ " " ++ partitionShowP nestedPref ob
-- | Split a list of arcs into a 'PartitionedGraph'.
--
--   Arcs are grouped by subject. Subjects whose label is not a variable
--   are queued first. Variable subjects are then split in two:
--
--   * those that appear as the object of exactly one arc are queued last,
--     as candidates for nesting inside the statement that references them;
--   * all other variable subjects are queued second.
partitionGraph :: (Label lb) => [Arc lb] -> PartitionedGraph lb
partitionGraph [] = PartitionedGraph []
partitionGraph arcs =
    makePartitions fixedSubjects (map dropRefs multiRefVars) (map dropRefs singleRefVars)
  where
    (fixedSubjects, varSubjects) =
        partition (\(lbl, _) -> not (labelIsVar lbl)) (collect arcSubj arcs)
    -- attach, to each variable subject, the arcs that use it as object
    varsWithRefs = collectMore arcObj arcs varSubjects
    (singleRefVars, multiRefVars) = partition referencedOnce varsWithRefs
    referencedOnce (_, (_, refs)) = case refs of
        [_] -> True
        _   -> False
    dropRefs (lbl, (stmts, _)) = (lbl, stmts)
-- | A subject label paired with the arcs that share it as subject.
type LabelledArcs lb = (lb, NonEmpty (Arc lb))
-- | A (predicate, object-partition) pair, i.e. one statement of a 'PartSub'.
type LabelledPartition lb = (lb, GraphPartition lb)
-- | Subjects still waiting to be partitioned, in the order
--   (non-variable subjects, variable subjects not referenced exactly once
--   as an object, variable subjects referenced exactly once as an object).
type MakePartitionState lb = ([LabelledArcs lb], [LabelledArcs lb], [LabelledArcs lb])
-- | State monad threading the pending-subject lists during partitioning.
type PState lb = State (MakePartitionState lb)
-- | Run the partitioning state machine over the three queues of pending
--   subjects and wrap the resulting partitions.
makePartitions ::
    (Eq lb) =>
    [LabelledArcs lb]
    -> [LabelledArcs lb]
    -> [LabelledArcs lb]
    -> PartitionedGraph lb
makePartitions fixs topv intv =
    PartitionedGraph (evalState (makePartitions1 []) initialQueues)
  where
    initialQueues = (fixs, topv, intv)
-- | Build partitions for the given subjects, left to right.
--
--   When the list is empty, the next subject is taken from the pending
--   queues. The function stops once all queues are exhausted.
makePartitions1 ::
    (Eq lb) =>
    [LabelledArcs lb]
    -> PState lb [GraphPartition lb]
makePartitions1 [] =
    pickNextSubject >>= \next -> case next of
        [] -> return []
        _  -> makePartitions1 next
makePartitions1 (sub:subs) =
    (++) <$> makePartitions2 sub <*> makePartitions1 subs
-- | Build the partition for one subject, followed by the partitions for
--   any extra subjects that 'makeStatements' hands back.
makePartitions2 ::
    (Eq lb) =>
    LabelledArcs lb
    -> PState lb [GraphPartition lb]
makePartitions2 subs = do
    (part, extra) <- makeStatements subs
    -- only recurse on a non-empty list: makePartitions1 [] would pull a
    -- fresh subject from the queues instead
    rest <- case extra of
        [] -> return []
        _  -> makePartitions1 extra
    return (part : rest)
-- | Build the 'PartSub' partition for a subject from all of its arcs.
--
--   Each arc becomes one (predicate, object-partition) statement via
--   'makeStatement'. Arcs are processed left to right, so queue updates
--   happen in arc order. The result pairs the partition with the
--   concatenated extra subjects returned by those calls.
--
--   Traversing the 'NonEmpty' directly keeps non-emptiness in the types,
--   rather than re-asserting it with the partial 'NE.fromList'.
makeStatements ::
    (Eq lb) =>
    LabelledArcs lb
    -> PState lb (GraphPartition lb, [LabelledArcs lb])
makeStatements (sub,stmts) = do
    results <- traverse makeStatement stmts
    let props    = fmap fst results
        moresubs = concatMap snd (NE.toList results)
    return (PartSub sub props, moresubs)
    
-- | Convert one arc into a (predicate, object-partition) statement.
--
--   If the object matches a pending subject in the third queue (variables
--   referenced once as an object), that subject is removed from the queue
--   and expanded in place via 'makeStatements'. Otherwise the object
--   becomes a 'PartObj' leaf. In that case, if the object matches a pending
--   subject in the second queue, that subject is removed and returned so
--   the caller can build it as a further partition.
makeStatement ::
    (Eq lb) =>
    Arc lb
    -> PState lb (LabelledPartition lb, [LabelledArcs lb])
makeStatement (Arc _ prop obj) = do
    intobj <- pickIntSubject obj
    -- pattern match instead of the partial `head` guarded by `null`
    (gpobj, moresubs) <- case intobj of
        (isub:_) -> makeStatements isub
        []       -> do
            ms <- pickVarSubject obj
            return (PartObj obj, ms)
    return ((prop,gpobj), moresubs)
-- | Pop the next subject to start a new partition.
--
--   The first queue is tried first, then the second, then the third. The
--   result is a singleton list, or @[]@ once all three queues are empty.
pickNextSubject :: PState lb [LabelledArcs lb]
pickNextSubject = state takeNext
  where
    takeNext (x:xs, ys,   zs)   = ([x], (xs, ys, zs))
    takeNext ([],   y:ys, zs)   = ([y], ([], ys, zs))
    takeNext ([],   [],   z:zs) = ([z], ([], [], zs))
    takeNext ([],   [],   [])   = ([],  ([], [], []))
-- | Remove from the third queue the first pending subject whose label
--   equals the given node, returning it as a singleton list. If no subject
--   matches, return @[]@ and leave the queues untouched.
pickIntSubject :: (Eq lb) =>
    lb
    -> PState lb [LabelledArcs lb]
pickIntSubject sub = do
    (fixed, top, internal) <- get
    case removeBy (\key (lbl, _) -> key == lbl) sub internal of
        Nothing -> return []
        Just (found, remaining) -> do
            put (fixed, top, remaining)
            return [found]
-- | Remove from the second queue the first pending subject whose label
--   equals the given node, returning it as a singleton list. If no subject
--   matches, return @[]@ and leave the queues untouched.
pickVarSubject ::
    (Eq lb) =>
    lb ->
    PState lb [LabelledArcs lb]
pickVarSubject sub = do
    (s1, s2, s3) <- get
    let sameLabel key (lbl, _) = key == lbl
    maybe (return [])
          (\(found, rest) -> put (s1, rest, s3) >> return [found])
          (removeBy sameLabel sub s2)
-- | Compare two partitioned graphs.
--
--   Returns the differing partitions as (left, right) pairs, where
--   'Nothing' on one side marks a partition with no counterpart. Both
--   partition lists are reversed before comparison.
comparePartitions :: (Label lb) =>
    PartitionedGraph lb
    -> PartitionedGraph lb
    -> [(Maybe (GraphPartition lb), Maybe (GraphPartition lb))]
comparePartitions (PartitionedGraph ps1) (PartitionedGraph ps2) =
    let rev1 = reverse ps1
        rev2 = reverse ps2
    in comparePartitions1 rev1 rev2
-- | Pair up partitions from two lists with 'comparePartitions2'.
--
--   Returns the differences found, followed by one-sided pairs for the
--   left-only and then the right-only partitions.
comparePartitions1 :: (Label lb) =>
    [GraphPartition lb]
    -> [GraphPartition lb]
    -> [(Maybe (GraphPartition lb),Maybe (GraphPartition lb))]
comparePartitions1 pg1 pg2 =
    let (matched, only1, only2) = listDifferences comparePartitions2 pg1 pg2
        leftOnly  = map (\p -> (Just p, Nothing)) only1
        rightOnly = map (\p -> (Nothing, Just p)) only2
    in matched ++ leftOnly ++ rightOnly
-- | Compare two partitions that might correspond to each other.
--
--   Returns 'Nothing' when the two are not considered to correspond, which
--   lets 'listDifferences' try another pairing. Otherwise returns
--   @Just diffs@ with the differing pairs; @Just []@ means no differences
--   were found.
comparePartitions2 :: (Label lb) =>
    GraphPartition lb
    -> GraphPartition lb
    -> Maybe [(Maybe (GraphPartition lb), Maybe (GraphPartition lb))]
-- Two leaves correspond when their labels match under 'matchNodes':
-- any two variables match, otherwise the labels must be equal.
comparePartitions2 (PartObj l1) (PartObj l2) =
    if matchNodes l1 l2 then Just [] else Nothing
-- Two subject partitions correspond when both roots are variables or the
-- roots are equal. Their statements are then compared pairwise.
comparePartitions2 pg1@(PartSub l1 p1s) pg2@(PartSub l2 p2s) =
    if match then comp1 else Nothing
    where
        -- NOTE(review): comparePartitions3, as currently written, always
        -- returns 'Just'. The 'Nothing' alternative is therefore
        -- unreachable, and both 'Just' alternatives pass the result through.
        comp1  = case comparePartitions3 l1 l2 p1s p2s of
                    Nothing -> if matchVar then Nothing
                                           else Just [(Just pg1,Just pg2)]
                    Just [] -> Just []
                    Just ps ->  Just ps
        matchVar = labelIsVar l1 && labelIsVar l2
        match    = matchVar || l1 == l2
-- A leaf against a subject partition, in either order. They are reported
-- as a differing pair only when they share the same non-variable label;
-- otherwise they do not correspond.
comparePartitions2 pg1 pg2 =
    if not (labelIsVar l1) && l1 == l2
        then Just [(Just pg1,Just pg2)]
        else Nothing
    where
        l1 = node pg1
        l2 = node pg2
-- | Compare the statement lists of two subject partitions.
--
--   Statements are paired with 'comparePartitions4'. An unpaired statement
--   is reported as a single-statement 'PartSub' under its own subject. The
--   result is always wrapped in 'Just'.
comparePartitions3 :: (Label lb) =>
    lb
    -> lb
    -> NonEmpty (LabelledPartition lb)
    -> NonEmpty (LabelledPartition lb)
    -> Maybe [(Maybe (GraphPartition lb),Maybe (GraphPartition lb))]
comparePartitions3 l1 l2 s1s s2s =
    Just (diffs ++ map onlyLeft rest1 ++ map onlyRight rest2)
  where
    (diffs, rest1, rest2) =
        listDifferences (comparePartitions4 l1 l2) (NE.toList s1s) (NE.toList s2s)
    onlyLeft  stmt = (Just (PartSub l1 (stmt :| [])), Nothing)
    onlyRight stmt = (Nothing, Just (PartSub l2 (stmt :| [])))
-- | Compare two (predicate, object) statements.
--
--   Statements are not paired unless their predicates match under
--   'matchNodes'. Once paired, if the objects do not correspond under
--   'comparePartitions2', the object pair itself is reported as the
--   difference.
comparePartitions4 :: (Label lb) =>
    lb
    -> lb
    -> LabelledPartition lb
    -> LabelledPartition lb
    -> Maybe [(Maybe (GraphPartition lb),Maybe (GraphPartition lb))]
comparePartitions4 _ _ (p1, o1) (p2, o2)
    | not (matchNodes p1 p2) = Nothing
    | otherwise              =
        maybe (Just [(Just o1, Just o2)]) Just (comparePartitions2 o1 o2)
-- | Label matching used when comparing partitions. A variable matches any
--   other variable; a non-variable matches only an equal label.
matchNodes :: (Label lb) => lb -> lb -> Bool
matchNodes l1 l2 = if labelIsVar l1 then labelIsVar l2 else l1 == l2
-- | Group items by key using '(==)'. See 'collectBy'.
collect :: (Eq b) => (a->b) -> [a] -> [(b, NonEmpty a)]
collect sel = collectBy (==) sel
-- | Group items by the key chosen by @sel@, using @cmp@ as key equality.
--   Groups appear in order of first occurrence, and items keep their
--   input order within each group.
collectBy :: (b->b->Bool) -> (a->b) -> [a] -> [(b, NonEmpty a)]
collectBy cmp sel items = map reverseCollection (collectBy1 cmp sel [] items)
-- | Fold each item into the accumulated groups with 'collectBy2'. Within
--   each group, items end up newest-first.
collectBy1 :: (b->b->Bool) -> (a->b) -> [(b, NonEmpty a)] -> [a] -> [(b, NonEmpty a)]
collectBy1 cmp sel = go
  where
    go groups []     = groups
    go groups (x:xs) = go (collectBy2 cmp sel x groups) xs
-- | Add one item to the groups.
--
--   The item is prepended to the first group whose key matches
--   (@cmp (sel a) key@). If no group matches, a new singleton group is
--   appended at the end.
collectBy2 :: (b->b->Bool) -> (a->b) -> a -> [(b, NonEmpty a)] -> [(b, NonEmpty a)]
collectBy2 cmp sel a cols =
    case break (\(k, _) -> cmp key k) cols of
        (before, (k, grp) : after) -> before ++ (k, a <| grp) : after
        (_, [])                    -> cols ++ [(key, a :| [])]
  where
    key = sel a
-- | Reverse the items of one group, leaving its key unchanged. This uses
--   the pair Functor, which maps over the second component.
reverseCollection :: (b, NonEmpty a) -> (b, NonEmpty a)
reverseCollection = fmap NE.reverse
-- | Attach extra items to existing keys using '(==)'. See 'collectMoreBy'.
collectMore :: (Eq b) => (a->b) -> [a] -> [(b,c)] -> [(b,(c,[a]))]
collectMore sel = collectMoreBy (==) sel
-- | For each existing @(key, value)@ entry, gather the items whose
--   selected key matches it under @cmp@, in input order. Items whose key
--   matches no entry are dropped. No new entries are created.
collectMoreBy ::
    (b->b->Bool) -> (a->b) -> [a] -> [(b,c)] -> [(b,(c,[a]))]
collectMoreBy cmp sel items cols =
    map reverseMoreCollection (collectMoreBy1 cmp sel items seeded)
  where
    seeded = [ (key, (val, [])) | (key, val) <- cols ]
-- | Fold each item into the entries with 'collectMoreBy2'. Within each
--   entry, items end up newest-first.
collectMoreBy1 ::
    (b->b->Bool) -> (a->b) -> [a] -> [(b,(c,[a]))] -> [(b,(c,[a]))]
collectMoreBy1 cmp sel = go
  where
    go []     entries = entries
    go (x:xs) entries = go xs (collectMoreBy2 cmp sel x entries)
-- | Prepend one item to the first entry whose key matches
--   (@cmp (sel a) key@). If no entry matches, the entries are returned
--   unchanged.
collectMoreBy2 ::
    (b->b->Bool) -> (a->b) -> a -> [(b,(c,[a]))] -> [(b,(c,[a]))]
collectMoreBy2 cmp sel a cols =
    case break (\(k, _) -> cmp (sel a) k) cols of
        (before, (k, (extra, grp)) : after) -> before ++ (k, (extra, a : grp)) : after
        (_, [])                             -> cols
-- | Reverse the collected items of one entry, leaving the key and the
--   original value unchanged.
reverseMoreCollection :: (b,(c,[a])) -> (b,(c,[a]))
reverseMoreCollection = fmap (fmap reverse)
-- | Remove the first element @x@ for which @cmp key x@ holds. Returns that
--   element together with the remaining elements in their original order,
--   or 'Nothing' if no element matches.
removeBy :: (b->a->Bool) -> b -> [a] -> Maybe (a,[a])
removeBy cmp key items = removeBy1 cmp key items []
-- | Worker for 'removeBy'. @before@ holds the already-skipped elements in
--   reverse order; on a match it is restored in front of the rest.
removeBy1 :: (b->a->Bool) -> b -> [a] -> [a] -> Maybe (a,[a])
removeBy1 _   _   []       _      = Nothing
removeBy1 cmp key (x:rest) before =
    if cmp key x
        then Just (x, reverseTo before rest)
        else removeBy1 cmp key rest (x : before)
-- | @reverseTo front back@ is @reverse front ++ back@, computed in a
--   single pass over @front@.
reverseTo :: [a] -> [a] -> [a]
reverseTo []     back = back
reverseTo (x:xs) back = reverseTo xs (x : back)
-- | Every way of picking one element out of a list, paired with the other
--   elements in their original order.
removeEach :: [a] -> [(a,[a])]
removeEach = go []
  where
    -- `seen` holds the elements already passed, most recent first
    go _    []     = []
    go seen (y:ys) = (y, reverse seen ++ ys) : go (y : seen) ys
-- | Pair elements of two lists using a partial comparison.
--
--   Elements of the first list are taken in order. Each one is tried
--   against every still-unused element of the second list. @cmp x y@
--   returns 'Nothing' if @x@ and @y@ cannot be paired, or @Just ds@ with
--   the differences found when pairing them. Among the possible partners,
--   the first one with no differences is preferred; failing that, the
--   first possible partner is used.
--
--   Returns the concatenated differences, the elements of the first list
--   that found no partner, and the elements of the second list that were
--   never used.
listDifferences :: (a->a->Maybe [d]) -> [a] -> [a] -> ([d],[a],[a])
listDifferences _   []     ys = ([], [], ys)
listDifferences cmp (x:xs) ys =
    case pickBest candidates of
        Nothing            -> carryOn [] [x] xs ys
        Just (found, rest) -> carryOn found [] xs rest
  where
    -- every possible partner for x: its diffs and the other unused ys
    candidates =
        [ (diffs, others) | (y, others) <- removeEach ys, Just diffs <- [cmp x y] ]
    pickBest []            = Nothing
    pickBest cands@(c : _) =
        case filter (null . fst) cands of
            exact : _ -> Just exact
            []        -> Just c
    carryOn acc unmatched rest1 rest2 = (acc ++ moreDiffs, unmatched ++ more1, more2)
      where
        (moreDiffs, more1, more2) = listDifferences cmp rest1 rest2