{-# LANGUAGE MultiParamTypeClasses,FlexibleInstances,FlexibleContexts,FunctionalDependencies,UndecidableInstances,RankNTypes,ExplicitForAll,ScopedTypeVariables,NoMonomorphismRestriction,OverlappingInstances,EmptyDataDecls,RecordWildCards,TypeFamilies,TemplateHaskell,TypeSynonymInstances  #-}

module GTA.Data.BinTree (LVTree (NodeLV, LeafLV), LVTreeAlgebra(LVTreeAlgebra), nodeLV, leafLV, LVTreeMapFs (LVTreeMapFs), leafLVF, BinTree(BinNode,BinLeaf), BinTreeAlgebra(BinTreeAlgebra),binNode,binLeaf, BinTreeMapFs (BinTreeMapFs), binLeafF, binNodeF, lvtrees, subtreeSelectsWithRoot, subtreeSelects, selects, assignTrans, assignTrees, count, maxsum, maxsumsolution) where

import GTA.Core
import GTA.Util.GenericSemiringStructureTemplate
import Data.List

-- | Leaf-valued binary tree: inner nodes carry no label, only the leaves
-- hold values.
data LVTree a = NodeLV (LVTree a) (LVTree a)
              | LeafLV a
                deriving (Eq, Ord, Read)

-- Automatic generation (via Template Haskell) of the declarations the GTA
-- framework needs for LVTree.  Judging from the export list, this presumably
-- includes the LVTreeAlgebra record (fields nodeLV, leafLV), the LVTreeMapFs
-- record (e.g. leafLVF) and the generic-semiring instance.
-- NOTE(review): confirm against GTA.Util.GenericSemiringStructureTemplate.
genAllDecl ''LVTree

-- Renaming: a generic semiring whose algebra is LVTreeAlgebra over leaf
-- values of type a, with result type s.
type LVTreeSemiring a s = GenericSemiring (LVTreeAlgebra a) s

-- | Generator: the bag of all leaf-valued binary trees whose left-to-right
-- sequence of leaves equals the input list, aggregated with the given
-- LVTree-semiring.
--
-- This is a generalised version of the matrix-chain DP algorithm (O(n^3)
-- semiring operations; the table is built by lvtrees').  Every LVTree has at
-- least one leaf, so for an empty input the bag is empty and the result is
-- the monoid's identity rather than a failure on an empty DP table.
lvtrees :: [a] -> LVTreeSemiring a s -> s
lvtrees x bts = case lvtrees' x bts of
    (whole : _) : _ -> whole
    _               -> identity (monoid bts)

-- Dynamic-programming table behind lvtrees.
--
-- The result is a list of rows, most recently computed first.  Row k
-- (0-based, counting from the bottom) has n - k entries; entry p of row k is
-- the aggregated bag of all trees whose leaves are x[p], ..., x[p+k].  Hence
-- the first row of the result holds a single entry covering the whole input
-- (when the input is non-empty; for an empty input the result is [[]]).
lvtrees' :: [a] -> LVTreeSemiring a s -> [[s]]
lvtrees' x (GenericSemiring{..}) =
    let CommutativeMonoid {..} = monoid
        LVTreeAlgebra {..} = algebra
        bigOplus = foldr oplus identity
        n = length x
        -- merge ts k: given ts = [row (k-1), ..., row 0], prepend row k.
        --   vs !! p lists the right parts of segment p..p+k (longest first):
        --     row (k-1) at p+1, row (k-2) at p+2, ..., row 0 at p+k.
        --   hs !! p lists the left parts starting at p (shortest first):
        --     row 0 at p, row 1 at p, ..., row (k-1) at p.
        -- Pairing them index-wise enumerates every split point of the segment.
        merge ts k =
            let vs = transpose (zipWith drop [1..k] ts)
                hs = map reverse (transpose ts)
                ns = zipWith mrg hs vs
            in ns:ts
        -- aggregate all (left, right) splits of one segment
        mrg h v = bigOplus (zipWith nodeLV h v)
    in foldl' merge [map leafLV x] [1..(n-1)]


-- | Binary tree with different types for nodes (@n@) and leaves (@l@):
-- every inner node carries an @n@ label and every leaf an @l@ label.
data BinTree n l = BinNode n (BinTree n l) (BinTree n l)
                 | BinLeaf l
             deriving (Eq, Ord, Read)


-- Template Haskell: generate the GTA declarations for BinTree.  Presumably
-- these are the BinTreeAlgebra and BinTreeMapFs records and the
-- GenericSemiringStructure instance that are spelled out by hand in the
-- commented-out block below.
genAllDecl ''BinTree

{-
{- this algebra can be generated automatically from BinTree -}
genAlgebraDecl ''BinTree
{-
data BinTreeAlgebra n l a = 
  BinTreeAlgebra {
    binNode :: n -> a -> a -> a,
    binLeaf :: l -> a
  }
-}
genMapFunctionsDecl ''BinTree
-- -- maps to a coherent data
-- data BinTreeMapFs n l a = BinTreeMapFs {
--         binNodeF :: (n -> a),
--         binLeafF :: (l -> a)
--       }

{- this instance can be generated automatically from BinTree -}
genInstanceDecl ''BinTree

-- the generic semiring structure of BinTreeAlgebra n l
-- instance GenericSemiringStructure (BinTreeAlgebra n l) (BinTree n l) (BinTreeMapFs n l) where
--   freeAlgebra = BinTreeAlgebra {..} where
--     binNode = BinNode
--     binLeaf = BinLeaf
--   hom (BinTreeAlgebra {..}) = h
--     where
--       h (BinNode a l r) = binNode a (h l) (h r)
--       h (BinLeaf a) = binLeaf a
--   pairAlgebra bt1 bt2 = BinTreeAlgebra {..} 
--     where
--       binNode a (l1, l2) (r1, r2) = (binNode1 a l1 r1, binNode2 a l2 r2)
--       binLeaf a = (binLeaf1 a, binLeaf2 a)
--       (binLeaf1, binNode1) = let BinTreeAlgebra {..} = bt1 in (binLeaf, binNode)
--       (binLeaf2, binNode2) = let BinTreeAlgebra {..} = bt2 in (binLeaf, binNode)
--   makeAlgebra (CommutativeMonoid {..}) bt frec fsingle = BinTreeAlgebra {..}
--     where  
--     binNode a l r = foldr oplus identity [fsingle (binNode' a l' r') | l' <- frec l, r' <- frec r]
--     binLeaf a = fsingle (binLeaf' a)
--     (binLeaf', binNode') = let BinTreeAlgebra {..} = bt in (binLeaf, binNode)
--   foldingAlgebra op iop (BinTreeMapFs {binNodeF=(binNodeF1),binLeafF=(binLeafF1)}) = BinTreeAlgebra {..}
--     where
--     binNode a l r = binNodeF1 a `op` l `op` r
--     binLeaf a = binLeafF1 a

-}

-- Renaming: a generic semiring whose algebra is BinTreeAlgebra with node
-- labels n and leaf labels l, and result type a.
type BinTreeSemiring n l a = GenericSemiring (BinTreeAlgebra n l) a


-- | BinTree-semiring for counting.  'sumproductBy' is given map functions
-- that send every node label and every leaf label to 1.  Assuming the usual
-- sum-of-products reading of 'sumproductBy', each tree then contributes 1
-- and the result is the number of trees in the bag.
count :: Num a => BinTreeSemiring n l a
count = sumproductBy allOnes
  where
    allOnes = BinTreeMapFs {binNodeF = const 1, binLeafF = const 1}

-- shortcuts to maxsum of marked trees
-- | Map functions for marked trees.  A node or leaf marked True contributes
-- its value and one marked False contributes 0; both are wrapped in
-- AddIdentity.
markedT :: forall a. Num a =>
                          BinTreeMapFs (Bool, a) (Bool, a) (AddIdentity a)
markedT = BinTreeMapFs {binLeafF = weight, binNodeF = weight}
  where
    -- the same weighting applies to node labels and leaf labels
    weight (isMarked, value) = AddIdentity (if isMarked then value else 0)

-- | BinTree-semiring built with 'maxsumBy' from 'markedT'.  'markedT' weighs
-- a marked node or leaf by its value and an unmarked one by 0.  Presumably
-- this yields the largest total weight over the bag of trees.
maxsum :: (Num a, Ord a) => BinTreeSemiring (Bool, a) (Bool, a) (AddIdentity a)
maxsum = maxsumBy markedT

-- | Like 'maxsum', but built with 'maxsumsolutionBy'.  Per its result type,
-- it pairs the weight with a bag of trees (presumably those attaining the
-- maximum).
maxsumsolution :: (Num a, Ord a) => BinTreeSemiring (Bool, a) (Bool, a) (AddIdentity a, Bag (BinTree (Bool, a) (Bool, a)))
maxsumsolution = maxsumsolutionBy markedT

-- | Classification for the predicate "the selection is a rooted subtree"
-- (i.e., one that includes the original root), computed by 'rtst'.  A node
-- or leaf counts as selected when the Bool in its label is True.
--   Rtd: the selected items form a subtree including the current root
--   Emp: nothing is selected
--   NG:  anything else
data RtStClass = Rtd | Emp | NG deriving (Show, Eq, Ord, Read)

-- | Algebra computing the 'RtStClass' of a marked tree.  It decides whether
-- the selected items (those whose label's Bool is True) form a subtree that
-- includes the root.
rtst :: forall t t1. BinTreeAlgebra (Bool, t1) (Bool, t) RtStClass
rtst = BinTreeAlgebra {..}
  where
  -- A selected node joins its children into a rooted subtree, unless one
  -- of the children is already unacceptable.
  binNode (True, _) l r
    | l /= NG && r /= NG = Rtd
    | otherwise          = NG
  -- An unselected node is acceptable only if nothing below it is selected.
  binNode (False, _) l r
    | l == Emp && r == Emp = Emp
    | otherwise            = NG
  binLeaf (selected, _)
    | selected  = Rtd
    | otherwise = Emp


-- | Classification for the predicate "the selection is a subtree anywhere in
-- the tree", computed by 'st'.  A node or leaf counts as selected when the
-- Bool in its label is True.
data StClass = RtdST  -- ^ a subtree including the root of the current tree
             | IsoST  -- ^ an isolated subtree (not including the root)
             | Empty  -- ^ nothing selected
             | Other  -- ^ any other, unacceptable selection
               deriving (Show, Eq, Ord, Read)

-- | Algebra computing the 'StClass' of a marked tree.  A node or leaf is
-- selected when the Bool in its label is True; the result is
--
--   * RtdST when the selection is one connected subtree containing the
--     current root,
--   * IsoST when it is one connected subtree not containing the root,
--   * Empty when nothing is selected,
--   * Other otherwise.
--
-- The two mirror-image cases for an unselected node, (IsoST, Empty) and
-- (Empty, IsoST), both yield IsoST.  Previously, a subtree lying deeper in
-- a left branch was wrongly classified as Other.
st :: forall t t1. BinTreeAlgebra (Bool, t1) (Bool, t) StClass
st = BinTreeAlgebra {..} where
  -- Selected node: each child must be empty or a subtree rooted at the
  -- child, so that the parent connects everything into one rooted subtree.
  binNode (True, _) l r
    | attachable l && attachable r = RtdST
    | otherwise                    = Other
  -- Unselected node: the selection must lie entirely within one child.
  binNode (False, _) l r = case (l, r) of
    (Empty, Empty)           -> Empty
    (Empty, c) | isSubtree c -> IsoST
    (c, Empty) | isSubtree c -> IsoST
    _                        -> Other
  binLeaf (m, _) = if m then RtdST else Empty
  -- child classes that can hang below a selected parent
  attachable c = c == RtdST || c == Empty
  -- child classes describing exactly one non-empty connected subtree
  isSubtree c = c == RtdST || c == IsoST



{-
  This is a BinTreeSemiring-polymorphic LVTreeSemiring.
  We can use this kind of A-semiring-polymorphic B-semiring to
  change the intermediate data structure from B to A.
-}
-- | Turn a BinTree-semiring into an LVTree-semiring by trying every mark.
-- The commutative monoid is reused unchanged.
--
--   * An LVTree node becomes the sum, over every node mark m in msn, of
--     binNode m applied to the two children.
--   * A leaf value a becomes the sum, over every leaf mark m in msl, of
--     binLeaf (m, a).
assignTrans :: [b] -> [c] -> BinTreeSemiring c (b, a) s -> LVTreeSemiring a s
assignTrans msl msn bts = GenericSemiring {monoid = mon, algebra = lvAlg}
    where
      mon = monoid bts
      binAlg = algebra bts
      -- fold a list of alternatives with the semiring's addition
      total = foldr (oplus mon) (identity mon)
      lvAlg = LVTreeAlgebra
        { nodeLV = \l r -> total (map (\m -> binNode binAlg m l r) msn)
        , leafLV = \a -> total (map (\m -> binLeaf binAlg (m, a)) msl)
        }


---generators
{-
Generator: a bag of marked binary trees built from a list.
The left-to-right traversal of each tree is equivalent to the input list,
ignoring the assigned marks.  Every leaf is paired (as (mark, value)) with
each mark from the first list, and every node gets each mark from the
second list.  The bag is obtained from lvtrees on the input, translated by
assignTrans (the two are combined with GTA's >=<).
-}
assignTrees :: [b] -> [c] -> [a] -> BinTreeSemiring c (b, a) s -> s
assignTrees msl msn x = lvtrees x >=< assignTrans msl msn

-- | Polymorphic generator for all selections.  Every node and leaf of the
-- given tree is independently marked True or False, and the resulting bag
-- of marked trees is aggregated with the given BinTree-semiring.
selects :: BinTree n l -> BinTreeSemiring (Bool,n) (Bool,l) a -> a
selects t bts = go t
  where
    alg = algebra bts
    -- both marks for one position, True first, combined with oplus
    eitherMark k = oplus (monoid bts) (k True) (k False)
    go (BinLeaf v) = eitherMark (\m -> binLeaf alg (m, v))
    go (BinNode v l r) = eitherMark (\m -> binNode alg (m, v) left right)
      where
        -- children are aggregated once and shared by both marks
        left  = go l
        right = go r


-- | Generator for the selections of @t@ that form a rooted subtree (one
-- including the root of @t@) or select nothing.  It takes the markings
-- produced by 'selects' and keeps those whose 'rtst' class is not NG
-- (presumably via GTA's filter >== with the predicate (/=NG) composed with
-- 'rtst' by <.>).
subtreeSelectsWithRoot :: BinTree n l -> BinTreeSemiring (Bool,n) (Bool,l) a -> a
subtreeSelectsWithRoot t = selects t >== (/=NG)<.>rtst
-- | Generator for the selections of @t@ whose 'st' class is not Other.
-- These are the selections classified as a subtree including the root
-- (RtdST), an isolated subtree (IsoST) or the empty selection (Empty).
subtreeSelects :: BinTree n l -> BinTreeSemiring (Bool,n) (Bool,l) a -> a
subtreeSelects t = selects t >== (/=Other)<.>st