{-# LANGUAGE MultiParamTypeClasses,FlexibleInstances,FlexibleContexts,FunctionalDependencies,UndecidableInstances,RankNTypes,ExplicitForAll,ScopedTypeVariables,NoMonomorphismRestriction,OverlappingInstances,EmptyDataDecls,RecordWildCards,TypeFamilies,TemplateHaskell,TypeSynonymInstances  #-}

-- | This module provides the GTA framework on binary (and
-- leaf-valued) trees, such as definitions of the data structures
-- and their algebras, generators, aggregators, etc.
module GTA.Data.BinTree (LVTree (NodeLV, LeafLV), LVTreeAlgebra(LVTreeAlgebra, nodeLV, leafLV), LVTreeMapFs (LVTreeMapFs, leafLVF), BinTree(BinNode,BinLeaf), BinTreeAlgebra(BinTreeAlgebra, binNode, binLeaf), BinTreeMapFs (BinTreeMapFs, binLeafF, binNodeF), lvtrees, subtreeSelectsWithRoot, subtreeSelects, selects, assignTrans, assignTrees, count, maxsum, maxsumsolution, LVTreeSemiring, BinTreeSemiring) where

import GTA.Core
import GTA.Util.GenericSemiringStructureTemplate
import Data.List

-- | Leaf-valued binary tree: internal nodes ('NodeLV') carry no value and
-- only the leaves ('LeafLV') hold values of type @a@.
data LVTree a = NodeLV (LVTree a) (LVTree a)  -- ^ internal node with left and right subtrees
              | LeafLV a  -- ^ leaf holding a value
                deriving (Eq, Ord, Read)

-- The following definitions can be generated automatically by @genAllDecl ''LVTree@.
-- They are written by hand here so that they can carry documentation comments.

-- | Algebra of 'LVTree': one function per constructor, where @b@ is the
-- type of leaf values and @a@ is the carrier (result) type.
data LVTreeAlgebra b a = LVTreeAlgebra {
      nodeLV  :: a -> a -> a,  -- ^ replaces 'NodeLV'; receives the results for the left and right subtrees
      leafLV :: b -> a  -- ^ replaces 'LeafLV'; receives the leaf value
    }

-- | A set of functions for 'map' over an 'LVTree': 'leafLVF' is applied to
-- every leaf value (it becomes the 'leafLV' of the algebra built by
-- 'foldingAlgebra').
data LVTreeMapFs b b' = LVTreeMapFs {
      leafLVF :: b -> b'
    }

-- | 'GenericSemiringStructure' instance for 'LVTree'. The three type
-- arguments are the algebra, the free algebra (the tree type itself), and
-- the record of functions used by 'map'.
instance GenericSemiringStructure (LVTreeAlgebra b) (LVTree b) (LVTreeMapFs b) where
  -- The free algebra rebuilds the tree with its own constructors.
  freeAlgebra = LVTreeAlgebra { nodeLV = NodeLV, leafLV = LeafLV }

  -- Run two algebras side by side over a pair carrier.
  pairAlgebra alg1 alg2 = LVTreeAlgebra { nodeLV = pairNode, leafLV = pairLeaf }
    where
      pairNode (x1, x2) (y1, y2) = (nodeLV alg1 x1 y1, nodeLV alg2 x2 y2)
      pairLeaf v                 = (leafLV alg1 v, leafLV alg2 v)

  -- Lift an algebra so that every combination of the elements produced by
  -- 'frec' for the children is rebuilt, injected by 'fsingle', and summed.
  makeAlgebra (CommutativeMonoid {..}) alg frec fsingle =
      LVTreeAlgebra { nodeLV = liftedNode, leafLV = liftedLeaf }
    where
      liftedNode l r = foldr oplus identity
        [ fsingle (nodeLV alg l' r') | l' <- frec l, r' <- frec r ]
      liftedLeaf v   = fsingle (leafLV alg v)

  -- Internal nodes combine the children's results with 'op'; leaves go
  -- through 'leafLVF'. The second operator is not needed for 'LVTree'.
  foldingAlgebra op _ (LVTreeMapFs {..}) =
      LVTreeAlgebra { nodeLV = op, leafLV = leafLVF }

  -- Structural recursion (catamorphism) replacing each constructor by the
  -- matching algebra function.
  hom alg@LVTreeAlgebra{} = go
    where
      go (NodeLV l r) = nodeLV alg (go l) (go r)
      go (LeafLV v)   = leafLV alg v

-- | Shorthand for a 'GenericSemiring' whose algebra is an 'LVTreeAlgebra'
-- over leaf values of type @a@, with carrier type @s@.
type LVTreeSemiring a s = GenericSemiring (LVTreeAlgebra a) s

-- | Generator of all leaf-valued trees over a list. Through the given
-- semiring it yields the bag of every 'LVTree' whose left-to-right sequence
-- of leaves equals the input list.
--
-- The bag is built by a generalized version of the matrix-chain dynamic
-- programming algorithm (see 'lvtrees''). It performs O(n^3) semiring
-- operations for an input of length n.
--
-- No tree has an empty list of leaves, so an empty input yields the identity
-- of the semiring's monoid (the value of an empty bag). Previously it crashed
-- on 'head'.
lvtrees :: [a] -> LVTreeSemiring a s -> s
lvtrees x bts = case lvtrees' x bts of
  -- The newest DP row has a single entry covering the whole input.
  ((whole : _) : _) -> whole
  -- Empty input: the table is @[[]]@.
  _                 -> identity (monoid bts)

-- | Dynamic-programming table behind 'lvtrees'.
--
-- Row k (0-based) holds, for every start position j, the semiring value of
-- all trees whose leaves are the k+1 consecutive input elements starting at
-- j. Rows are built bottom-up and returned newest first. For a non-empty
-- input, the head row therefore ends up with a single entry covering the
-- whole list.
lvtrees' :: [a] -> LVTreeSemiring a s -> [[s]]
lvtrees' x (GenericSemiring{..}) = foldl' step [leaves] [1 .. n - 1]
  where
    CommutativeMonoid {..} = monoid
    LVTreeAlgebra {..}     = algebra
    n      = length x
    leaves = map leafLV x
    sumAll = foldr oplus identity
    -- Given the rows for segment lengths k, k-1, .., 1 (newest first), add
    -- the row for length k+1. Column j of 'lefts' lists the left parts
    -- (lengths 1..k) starting at j. Column j of 'rights' lists the matching
    -- right parts (lengths k..1) that start right after them.
    step rows k =
      let lefts  = map reverse (transpose rows)
          rights = transpose (zipWith drop [1 .. k] rows)
      in zipWith (\ls rs -> sumAll (zipWith nodeLV ls rs)) lefts rights : rows


-- | Binary tree with different types for nodes and leaves: each internal
-- node ('BinNode') carries a value of type @n@, and each leaf ('BinLeaf')
-- carries a value of type @l@.
data BinTree n l = BinNode n (BinTree n l) (BinTree n l)
                 | BinLeaf l
             deriving (Eq, Ord, Read)


-- The following definitions can be generated automatically by @genAllDecl ''BinTree@.
-- They are written by hand here so that they can carry documentation comments.

-- | Algebra of 'BinTree': one function per constructor, where @n@ is the
-- node-value type, @l@ the leaf-value type and @a@ the carrier type.
data BinTreeAlgebra n l a = BinTreeAlgebra {
      binNode :: n -> a -> a -> a,  -- ^ replaces 'BinNode'; receives the node value and the results for the left and right subtrees
      binLeaf :: l -> a  -- ^ replaces 'BinLeaf'; receives the leaf value
    }

-- | A set of functions for 'map' over a 'BinTree': 'binNodeF' is applied to
-- node values and 'binLeafF' to leaf values. Both map into a common type
-- @b'@, as combined by 'foldingAlgebra'.
data BinTreeMapFs n l b' = BinTreeMapFs {
      binNodeF :: n -> b',
      binLeafF :: l -> b'
    }

-- | 'GenericSemiringStructure' instance for 'BinTree'. The three type
-- arguments are the algebra, the free algebra (the tree type itself), and
-- the record of functions used by 'map'.
instance GenericSemiringStructure (BinTreeAlgebra n l) (BinTree n l) (BinTreeMapFs n l) where
  -- The free algebra rebuilds the tree with its own constructors.
  freeAlgebra = BinTreeAlgebra { binNode = BinNode, binLeaf = BinLeaf }

  -- Run two algebras side by side over a pair carrier; node and leaf
  -- values are shared by both sides.
  pairAlgebra alg1 alg2 = BinTreeAlgebra { binNode = pairNode, binLeaf = pairLeaf }
    where
      pairNode v (x1, x2) (y1, y2) = (binNode alg1 v x1 y1, binNode alg2 v x2 y2)
      pairLeaf v                   = (binLeaf alg1 v, binLeaf alg2 v)

  -- Lift an algebra so that every combination of the elements produced by
  -- 'frec' for the children is rebuilt, injected by 'fsingle', and summed.
  makeAlgebra (CommutativeMonoid {..}) alg frec fsingle =
      BinTreeAlgebra { binNode = liftedNode, binLeaf = liftedLeaf }
    where
      liftedNode v l r = foldr oplus identity
        [ fsingle (binNode alg v l' r') | l' <- frec l, r' <- frec r ]
      liftedLeaf v     = fsingle (binLeaf alg v)

  -- A node's mapped value is combined with its children's results
  -- left-to-right using 'op'. The second operator is not needed here.
  foldingAlgebra op _ (BinTreeMapFs {..}) =
      BinTreeAlgebra { binNode = combine, binLeaf = binLeafF }
    where
      combine v l r = (binNodeF v `op` l) `op` r

  -- Structural recursion (catamorphism) replacing each constructor by the
  -- matching algebra function.
  hom alg@BinTreeAlgebra{} = go
    where
      go (BinNode v l r) = binNode alg v (go l) (go r)
      go (BinLeaf v)     = binLeaf alg v


-- | Shorthand for a 'GenericSemiring' whose algebra is a 'BinTreeAlgebra'
-- with node type @n@ and leaf type @l@, with carrier type @a@.
type BinTreeSemiring n l a = GenericSemiring (BinTreeAlgebra n l) a


-- | BinTree semiring for counting. Every node and leaf is mapped to 1
-- before 'sumproductBy' combines them. Presumably it multiplies within a
-- tree and sums over the bag, which yields the number of trees; confirm
-- against 'sumproductBy'.
count :: Num a => BinTreeSemiring n l a
count = sumproductBy ones
  where
    ones = BinTreeMapFs { binNodeF = \_ -> 1, binLeafF = \_ -> 1 }

-- | Shortcut used by 'maxsum' and 'maxsumsolution' over marked trees. A
-- marked node or leaf contributes its value and an unmarked one contributes
-- 0, wrapped in 'AddIdentity'.
markedT :: forall a. Num a =>
                          BinTreeMapFs (Bool, a) (Bool, a) (AddIdentity a)
markedT = BinTreeMapFs { binLeafF = weight, binNodeF = weight }
  where
    weight (marked, v) = AddIdentity contribution
      where
        contribution
          | marked    = v
          | otherwise = 0

-- | 'maxsumBy' instantiated with 'markedT': each marked node or leaf
-- contributes its value and each unmarked one contributes 0. Presumably
-- this yields the largest total of marked values over the bag of trees.
maxsum :: (Num a, Ord a) => BinTreeSemiring (Bool, a) (Bool, a) (AddIdentity a)
maxsum = maxsumBy markedT

-- | 'maxsumsolutionBy' instantiated with 'markedT'. The weight is paired
-- with a 'Bag' of trees, presumably the trees attaining the maximum (confirm
-- against 'maxsumsolutionBy').
maxsumsolution :: (Num a, Ord a) => BinTreeSemiring (Bool, a) (Bool, a) (AddIdentity a, Bag (BinTree (Bool, a) (Bool, a)))
maxsumsolution = maxsumsolutionBy markedT

-- | Classification used by 'rtst' to recognize rooted subtrees (selections
-- that include the original root). 'Rtd' is a connected selection
-- containing the current root, 'Emp' means nothing is selected, and 'NG'
-- is anything else (rejected).
data RtStClass = Rtd | Emp | NG deriving (Show, Eq, Ord, Read)

-- | Algebra classifying a selection (the 'Bool' marks on nodes and leaves)
-- of a 'BinTree' by 'RtStClass'.
--
-- * A marked node is 'Rtd' when neither child is 'NG', and 'NG' otherwise.
-- * An unmarked node is 'Emp' when both children are 'Emp', and 'NG'
--   otherwise, since a selection below an unmarked node cannot reach the
--   root.
-- * A leaf is 'Rtd' when marked and 'Emp' otherwise.
rtst :: forall t t1. BinTreeAlgebra (Bool, t1) (Bool, t) RtStClass
rtst = BinTreeAlgebra { binNode = node, binLeaf = leaf }
  where
    node (marked, _) l r
      | marked    = if l /= NG && r /= NG then Rtd else NG
      | otherwise = if l == Emp && r == Emp then Emp else NG
    leaf (marked, _)
      | marked    = Rtd
      | otherwise = Emp


-- | Classification used by 'st' to recognize selections that form a single
-- connected subtree anywhere in the tree.
data StClass = RtdST  -- ^ connected selection that includes the current root
             | IsoST  -- ^ isolated connected selection lying strictly below the current root
             | Empty  -- ^ nothing selected
             | Other  -- ^ any other (rejected) selection
               deriving (Show, Eq, Ord, Read)

-- | Algebra classifying a selection (the 'Bool' marks on nodes and leaves)
-- of a 'BinTree' by 'StClass'.
--
-- * A marked node is 'RtdST' when each child is 'RtdST' or 'Empty', which
--   keeps the selection connected through it. Otherwise it is 'Other'.
-- * An unmarked node is 'Empty' when both children are 'Empty'. It is
--   'IsoST' when one child holds a connected selection ('RtdST' or 'IsoST')
--   and the other is 'Empty'. Otherwise it is 'Other'. This rule is
--   symmetric in the two children. The old table mapped @(IsoST, Empty)@
--   to 'Other' while mapping @(Empty, IsoST)@ to 'IsoST'. That wrongly
--   rejected connected selections hanging down the left side of unselected
--   nodes.
-- * A leaf is 'RtdST' when marked and 'Empty' otherwise.
st :: forall t t1. BinTreeAlgebra (Bool, t1) (Bool, t) StClass
st = BinTreeAlgebra { binNode = node, binLeaf = leaf }
  where
    -- A child result that can hang below a selected node without breaking
    -- connectivity.
    attachable c = c == RtdST || c == Empty

    node (True, _) l r
      | attachable l && attachable r = RtdST
      | otherwise                    = Other
    node (False, _) l r = case (l, r) of
      (Empty, Empty) -> Empty
      (Empty, c)     -> isolate c
      (c, Empty)     -> isolate c
      _              -> Other

    -- A connected selection below an unselected node becomes isolated;
    -- anything else stays rejected.
    isolate RtdST = IsoST
    isolate IsoST = IsoST
    isolate _     = Other

    leaf (m, _) = if m then RtdST else Empty



-- | A BinTreeSemiring-polymorphic LVTreeSemiring. Any 'BinTreeSemiring'
-- becomes an 'LVTreeSemiring' as follows:
--
-- * each 'NodeLV' is replaced by a 'binNode' for every mark in @msn@;
-- * each leaf value @a@ is replaced by @'binLeaf' (m, a)@ for every mark
--   @m@ in @msl@;
-- * the alternatives are summed with the monoid's 'oplus'.
--
-- An A-semiring-polymorphic B-semiring like this lets us change the
-- intermediate data structure from B to A.
assignTrans :: [b] -> [c] -> BinTreeSemiring c (b, a) s -> LVTreeSemiring a s
assignTrans msl msn bts =
    GenericSemiring
      { monoid  = cm
      , algebra = LVTreeAlgebra { nodeLV = node, leafLV = leaf }
      }
  where
    cm     = monoid bts
    alg    = algebra bts
    sumAll = foldr (oplus cm) (identity cm)
    node l r = sumAll [ binNode alg mark l r | mark <- msn ]
    leaf v   = sumAll [ binLeaf alg (mark, v) | mark <- msl ]


-- Generators

-- | Generator of all leaf-valued trees over a list, with a mark assigned to
-- every node and leaf. It is 'lvtrees' combined via '>=<' with
-- 'assignTrans'. Each leaf gets a mark from @msl@ (paired with its value)
-- and each internal node a mark from @msn@. Ignoring the marks, the
-- left-to-right sequence of leaves of every generated tree equals the input
-- list.
assignTrees :: [b] -> [c] -> [a] -> BinTreeSemiring c (b, a) s -> s
assignTrees leafMarks nodeMarks xs =
  (>=<) (lvtrees xs) (assignTrans leafMarks nodeMarks)

-- | Polymorphic generator of all selections of a tree. Every node and leaf
-- is independently marked 'True' (selected) or 'False'. At each position
-- the results for both marks are combined with the semiring's 'oplus'.
selects :: BinTree n l -> BinTreeSemiring (Bool,n) (Bool,l) a -> a
selects t bts = go t
  where
    alg = algebra bts
    cm  = monoid bts
    -- Sum the selected and the unselected alternative (selected first).
    eitherMark k = oplus cm (k True) (k False)
    go (BinLeaf v)     = eitherMark (\m -> binLeaf alg (m, v))
    go (BinNode v l r) =
      -- Children are evaluated once and shared by both alternatives.
      let gl = go l
          gr = go r
      in eitherMark (\m -> binNode alg (m, v) gl gr)


-- | Generator of the selections of a tree that form a subtree containing
-- the root; the empty selection ('Emp') also passes. It combines 'selects'
-- with the predicate @(/=NG)@ over the 'rtst' classification, using '>=='
-- and '<.>'. These are presumably GTA.Core's filter combinators, so only
-- selections not classified as 'NG' are kept.
subtreeSelectsWithRoot :: BinTree n l -> BinTreeSemiring (Bool,n) (Bool,l) a -> a
subtreeSelectsWithRoot t = selects t >== (/=NG)<.>rtst
-- | Generator of the selections of a tree that form a single connected
-- subtree anywhere in the tree, or are empty. It combines 'selects' with
-- the predicate @(/=Other)@ over the 'st' classification, using '>==' and
-- '<.>'. These are presumably GTA.Core's filter combinators, so only
-- selections not classified as 'Other' are kept.
subtreeSelects :: BinTree n l -> BinTreeSemiring (Bool,n) (Bool,l) a -> a
subtreeSelects t = selects t >== (/=Other)<.>st