{-# LANGUAGE MultiParamTypeClasses,FlexibleInstances,FlexibleContexts,FunctionalDependencies,UndecidableInstances,RankNTypes,ExplicitForAll,ScopedTypeVariables,NoMonomorphismRestriction,OverlappingInstances,EmptyDataDecls,RecordWildCards,TypeFamilies,TemplateHaskell  #-}

module GTA.Data.JoinList (JoinList(Times, Single, Nil), JoinListAlgebra(JoinListAlgebra), times, single, nil, joinize, dejoinize, segs, inits, tails, subs, assigns, paths, mapJ, count, maxsum, maxsumsolution, maxsumWith, maxsumKWith, maxsumsolutionXKWith, maxsumsolutionXWith, maxsumsolutionWith, maxsumsolutionKWith, maxprodWith, maxprodKWith, maxprodsolutionXKWith, maxprodsolutionXWith, maxprodsolutionWith, maxprodsolutionKWith, segsP, initsP, tailsP, subsP, assignsP) where

import GTA.Core
import GTA.Util.GenericSemiringStructureTemplate
import GTA.Data.BinTree (BinTree (..))
import Control.Parallel

{- example of the usual semirings -}

-- join list = associative binary tree
-- A list represented as a binary tree: `Times` joins two sub-lists, `Single`
-- holds one element and `Nil` is the empty list.  Trees of different shapes
-- that flatten to the same sequence are treated as the same list by the
-- Show/Read/Eq/Ord instances in this module, which all go through
-- `dejoinize`/`joinize`.
data JoinList a = Times (JoinList a) (JoinList a)
                | Single a
                | Nil
-- The instances are written by hand rather than derived; a derived Eq/Ord
-- would compare tree shapes instead of the flattened sequences.
--             deriving (Show, Eq, Ord, Read)

-- Convert an ordinary list into a join list.
-- The empty list becomes `Nil`, a one-element list becomes `Single`, and any
-- longer list is cut at `length xs `div` 2`; both halves are converted
-- recursively and joined with `Times`, so the resulting tree is balanced.
joinize :: forall a. [a] -> JoinList a
joinize xs = case xs of
    []  -> Nil
    [a] -> Single a
    _   -> Times (joinize front) (joinize back)
  where
    half          = length xs `div` 2
    (front, back) = splitAt half xs

-- Flatten a join list into an ordinary list, keeping the left-to-right order
-- of its leaves.
--
-- The traversal threads the already-flattened remainder through as an
-- accumulator instead of appending sub-results with (++).  Nested (++)
-- re-copies the left operand at every `Times` node, which is quadratic for
-- left-leaning trees; this version is linear in the size of the tree whatever
-- its shape, and the result is still produced lazily.
dejoinize :: forall a. JoinList a -> [a]
dejoinize t = go t []
  where
    -- go u rest  ==  the elements of u, followed by rest
    go (Times l r) rest = go l (go r rest)
    go (Single a)  rest = a : rest
    go Nil         rest = rest

-- A join list is shown exactly like the ordinary list it flattens to.
instance Show a => Show (JoinList a) where
    showsPrec prec = showsPrec prec . dejoinize

-- A join list is read using ordinary list syntax; every successful list parse
-- is converted with `joinize` and paired with its unconsumed input.
instance Read a => Read (JoinList a) where
    readsPrec prec input =
        [ (joinize xs, rest) | (xs, rest) <- readsPrec prec input ]

-- Equality compares the flattened element sequences, so tree shape is ignored.
instance Eq a => Eq (JoinList a) where
    l == r = dejoinize l == dejoinize r

-- Ordering is the list ordering of the flattened element sequences.
instance Ord a => Ord (JoinList a) where
    compare l r = dejoinize l `compare` dejoinize r


-- to use the GTA framework:
-- Template Haskell splice applied to the JoinList type.  Presumably this is
-- what generates JoinListAlgebra, JoinListMapFs and the instance behind `hom`
-- that the rest of this module relies on (none of them is declared by hand in
-- this file) -- confirm against GTA.Util.GenericSemiringStructureTemplate.
genAllDecl ''JoinList

-- renaming: `Semiring a s` abbreviates a GenericSemiring over the join-list
-- algebra with element type `a`; `s` is the type the generators in this
-- module return when they are given such a semiring.
type Semiring a s= GenericSemiring (JoinListAlgebra a) s

-- The plain (non-parallel) fold of a join list with a join-list algebra:
-- simply `hom` at the JoinList types.
sequentialJoinListHom :: forall t a. JoinListAlgebra t a -> JoinList t -> a
sequentialJoinListHom alg xs = hom alg xs
-- Segment generator on ordinary lists: converts the list with `joinize` and
-- runs `segsJ` with the sequential homomorphism.
segs :: [a] -> Semiring a s -> s
segs xs = segsJ sequentialJoinListHom (joinize xs)
-- Prefix generator on ordinary lists: converts the list with `joinize` and
-- runs `initsJ` with the sequential homomorphism.
inits :: [a] -> Semiring a s -> s
inits xs = initsJ sequentialJoinListHom (joinize xs)
-- Suffix generator on ordinary lists: converts the list with `joinize` and
-- runs `tailsJ` with the sequential homomorphism.
tails :: [a] -> Semiring a s -> s
tails xs = tailsJ sequentialJoinListHom (joinize xs)
-- Sub-list generator on ordinary lists: converts the list with `joinize` and
-- runs `subsJ` with the sequential homomorphism.
subs :: [a] -> Semiring a s -> s
subs xs = subsJ sequentialJoinListHom (joinize xs)
-- Assignment generator on ordinary lists: converts the list with `joinize`
-- and runs `assignsJ` with the candidate marks `ms` and the sequential
-- homomorphism.
assigns :: [m] -> [a] -> Semiring (m, a) s -> s
assigns ms xs = assignsJ sequentialJoinListHom ms (joinize xs)

-- Segment generator on join lists.
--
-- `h` is the homomorphism used to fold the join list (the sequential or the
-- parallel one).  The list is folded once with an algebra on 4-tuples built
-- from the semiring's own operations.  Reading `oplus` as "either" and
-- `times` as concatenation (the usual GTA interpretation, not spelled out in
-- this file), the components stand for
--   (non-empty segments, non-empty prefixes, non-empty suffixes, whole list).
-- Only the first component is kept; the final `oplus nil` adds the empty
-- segment.
segsJ :: (forall b s.JoinListAlgebra b s -> JoinList b -> s) -> JoinList a -> Semiring a s -> s
segsJ h x (GenericSemiring {..}) = allSegs `oplus` nil
    where
      (allSegs, _, _, _) = h quadAlgebra x
      quadAlgebra = JoinListAlgebra {times=joinQuads,single=unitQuad,nil=emptyQuad}
      -- lazy patterns: the tuples are only taken apart when a component is demanded
      joinQuads ~(s1, i1, t1, a1) ~(s2, i2, t2, a2) =
          ( (s1 `oplus` s2) `oplus` (t1 `times` i2)  -- inside either half, or straddling the join
          , i1 `oplus` (a1 `times` i2)               -- prefix of the left, or all of the left plus a prefix of the right
          , (t1 `times` a2) `oplus` t2               -- suffix of the left plus all of the right, or suffix of the right
          , a1 `times` a2 )
      unitQuad a = (sa, sa, sa, sa) where sa = single a
      emptyQuad = (identity, identity, identity, nil)
      -- bring the semiring's operations (times, single, nil, oplus, identity) into scope
      JoinListAlgebra {..} = algebra
      CommutativeMonoid {..} = monoid
          
-- Prefix generator on join lists.
--
-- `h` is the homomorphism used to fold the join list.  The fold works on
-- pairs; under the usual GTA reading of `oplus`/`times` (not spelled out in
-- this file) these are (non-empty prefixes, whole list).  The first
-- component is kept and `nil` is added in front for the empty prefix.
initsJ :: (forall b s.JoinListAlgebra b s -> JoinList b -> s) -> JoinList a -> Semiring a s -> s
initsJ h x (GenericSemiring {..}) = nil `oplus` prefixes
    where
      (prefixes, _) = h pairAlgebra x
      pairAlgebra = JoinListAlgebra {times=joinPairs,single=unitPair,nil=emptyPair}
      -- lazy patterns: the pairs are only taken apart when a component is demanded
      joinPairs ~(i1, a1) ~(i2, a2) =
          (i1 `oplus` (a1 `times` i2), a1 `times` a2)
      unitPair a = (sa, sa) where sa = single a
      emptyPair = (identity, nil)
      -- bring the semiring's operations into scope
      JoinListAlgebra {..} = algebra
      CommutativeMonoid {..} = monoid

-- Suffix generator on join lists.
--
-- `h` is the homomorphism used to fold the join list.  The fold works on
-- pairs; under the usual GTA reading of `oplus`/`times` (not spelled out in
-- this file) these are (non-empty suffixes, whole list).  The first
-- component is kept and `nil` is added after it for the empty suffix.
tailsJ :: (forall b s.JoinListAlgebra b s -> JoinList b -> s) -> JoinList a -> Semiring a s -> s
tailsJ h x (GenericSemiring {..}) = suffixes `oplus` nil
    where
      (suffixes, _) = h pairAlgebra x
      pairAlgebra = JoinListAlgebra {times=joinPairs,single=unitPair,nil=emptyPair}
      -- lazy patterns: the pairs are only taken apart when a component is demanded
      joinPairs ~(t1, a1) ~(t2, a2) =
          ((t1 `times` a2) `oplus` t2, a1 `times` a2)
      unitPair a = (sa, sa) where sa = single a
      emptyPair = (identity, nil)
      -- bring the semiring's operations into scope
      JoinListAlgebra {..} = algebra
      CommutativeMonoid {..} = monoid

-- Sub-list generator on join lists.
--
-- Folds the join list with `h`, using the semiring's own `times` and `nil`
-- but replacing `single a` by `single a `oplus` nil`, i.e. at every element
-- both "the element" and "nothing" are offered.
subsJ :: (forall b s.JoinListAlgebra b s -> JoinList b -> s) -> JoinList a -> Semiring a s -> s
subsJ h x (GenericSemiring {..}) = h keepOrDrop x
    where
      keepOrDrop = JoinListAlgebra {times=times,single=optional,nil=nil}
      optional a = single a `oplus` nil
      -- bring the semiring's operations into scope
      JoinListAlgebra {..} = algebra
      CommutativeMonoid {..} = monoid
          
-- Assignment generator on join lists.
--
-- Folds the join list with `h`, using the semiring's own `times` and `nil`
-- but replacing each element `a` by the `oplus`-combination of
-- `single (m, a)` over every mark `m` in `ms` (starting from `identity`, so
-- an empty `ms` yields `identity` for every element).
assignsJ :: (forall b s.JoinListAlgebra b s -> JoinList b -> s) -> [m] -> JoinList a -> Semiring (m,a) s -> s
assignsJ h ms x (GenericSemiring {..}) = h marked x
    where
      marked = JoinListAlgebra {times=times,single=anyMark,nil=nil}
      anyMark a = foldr (\m rest -> single (m, a) `oplus` rest) identity ms
      -- bring the semiring's operations into scope
      JoinListAlgebra {..} = algebra
      CommutativeMonoid {..} = monoid

{- this generates lists from a tree, while CYK generates trees from a list -}
-- Path generator on binary trees whose nodes and leaves carry the same type.
-- A leaf contributes `single` of its label; a node contributes `single` of
-- its label `times` the `oplus` of the results for its two subtrees, so every
-- generated list runs from the root down to one leaf.
paths :: BinTree a a -> Semiring a s -> s
paths x (GenericSemiring {..}) = walk x
    where
      walk tree = case tree of
          BinLeaf a     -> single a
          BinNode a l r -> single a `times` (walk l `oplus` walk r)
      -- bring the semiring's operations into scope
      JoinListAlgebra {..} = algebra
      CommutativeMonoid {..} = monoid

-- useful function to map: wraps an element function as the `singleF` field of
-- a JoinListMapFs record, the argument form taken by the `...By` semiring
-- constructors used below.
mapJ :: forall b a. (b -> a) -> JoinListMapFs b a
mapJ f = JoinListMapFs {singleF = f}

-- JoinList-semiring for counting: `sumproductBy` with every element mapped
-- to the constant 1.
count :: Num a => Semiring b a
count = sumproductBy (mapJ (const 1))


{- simplified aggregators -}

-- `maxsumBy` with every element used as its own value (wrapped by
-- `addIdentity`).
maxsum :: (Ord a, Num a) => Semiring a (AddIdentity a)
maxsum = maxsumBy (mapJ addIdentity)

-- `maxsumsolutionBy` with every element used as its own value (wrapped by
-- `addIdentity`); the result pairs the value with a bag of join lists.
maxsumsolution :: (Ord a, Num a) => Semiring a (AddIdentity a, Bag (JoinList a))
maxsumsolution = maxsumsolutionBy (mapJ addIdentity)

-- `maxsumBy` where the value of each element is computed by `f` (and wrapped
-- by `addIdentity`).
maxsumWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a)
maxsumWith f = maxsumBy values
  where values = mapJ (addIdentity . f)

-- `maxsumKBy` with parameter `k`, where the value of each element is computed
-- by `f` (and wrapped by `addIdentity`); the result is a list of values.
maxsumKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b ([AddIdentity a])
maxsumKWith k f = maxsumKBy k values
  where values = mapJ (addIdentity . f)

-- `maxsumsolutionXKBy` with the auxiliary semiring `s` and parameter `k`,
-- where the value of each element is computed by `f` (and wrapped by
-- `addIdentity`); the result is a list of (value, result of `s`) pairs.
maxsumsolutionXKWith :: (Ord a, Num a) =>
                       Semiring c b -> Int -> (c -> a) -> Semiring c [(AddIdentity a, b)]
maxsumsolutionXKWith s k f = maxsumsolutionXKBy s k values
  where values = mapJ (addIdentity . f)

-- `maxsumsolutionXBy` with the auxiliary semiring `s`, where the value of
-- each element is computed by `f` (and wrapped by `addIdentity`); the result
-- pairs the value with a result of `s`.
maxsumsolutionXWith :: (Ord a, Num a) =>
                       Semiring c b -> (c -> a) -> Semiring c (AddIdentity a, b)
maxsumsolutionXWith s f = maxsumsolutionXBy s values
  where values = mapJ (addIdentity . f)

-- `maxsumsolutionBy` where the value of each element is computed by `f` (and
-- wrapped by `addIdentity`); the result pairs the value with a bag of join
-- lists.
maxsumsolutionWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a, Bag (JoinList b))
maxsumsolutionWith f = maxsumsolutionBy values
  where values = mapJ (addIdentity . f)

-- `maxsumsolutionKBy` with parameter `k`, where the value of each element is
-- computed by `f` (and wrapped by `addIdentity`); the result is a list of
-- (value, bag of join lists) pairs.
maxsumsolutionKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b [(AddIdentity a, Bag (JoinList b))]
maxsumsolutionKWith k f = maxsumsolutionKBy k values
  where values = mapJ (addIdentity . f)

-- `maxprodBy` where the value of each element is computed by `f` (and wrapped
-- by `addIdentity`).
maxprodWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a)
maxprodWith f = maxprodBy values
  where values = mapJ (addIdentity . f)

-- `maxprodKBy` with parameter `k`, where the value of each element is
-- computed by `f` (and wrapped by `addIdentity`); the result is a list of
-- values.
maxprodKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b ([AddIdentity a])
maxprodKWith k f = maxprodKBy k values
  where values = mapJ (addIdentity . f)

-- `maxprodsolutionXKBy` with the auxiliary semiring `s` and parameter `k`,
-- where the value of each element is computed by `f` (and wrapped by
-- `addIdentity`); the result is a list of (value, result of `s`) pairs.
maxprodsolutionXKWith :: (Ord a, Num a) =>
                       Semiring c b -> Int -> (c -> a) -> Semiring c [(AddIdentity a, b)]
maxprodsolutionXKWith s k f = maxprodsolutionXKBy s k values
  where values = mapJ (addIdentity . f)
-- `maxprodsolutionXBy` with the auxiliary semiring `s`, where the value of
-- each element is computed by `f` (and wrapped by `addIdentity`); the result
-- pairs the value with a result of `s`.
maxprodsolutionXWith :: (Ord a, Num a) =>
                       Semiring c b -> (c -> a) -> Semiring c (AddIdentity a, b)
maxprodsolutionXWith s f = maxprodsolutionXBy s values
  where values = mapJ (addIdentity . f)

-- `maxprodsolutionBy` where the value of each element is computed by `f` (and
-- wrapped by `addIdentity`); the result pairs the value with a bag of join
-- lists.
maxprodsolutionWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a, Bag (JoinList b))
maxprodsolutionWith f = maxprodsolutionBy values
  where values = mapJ (addIdentity . f)

-- `maxprodsolutionKBy` with parameter `k`, where the value of each element is
-- computed by `f` (and wrapped by `addIdentity`); the result is a list of
-- (value, bag of join lists) pairs.
maxprodsolutionKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b [(AddIdentity a, Bag (JoinList b))]
maxprodsolutionKWith k f = maxprodsolutionKBy k values
  where values = mapJ (addIdentity . f)


--- parallel generators

-- Parallel counterpart of `segs`: same generator (`segsJ` over the `joinize`d
-- list), folded with `parallelJoinListHom`.
segsP :: [a] -> Semiring a s -> s
segsP xs = segsJ parallelJoinListHom (joinize xs)
-- Parallel counterpart of `inits`: same generator (`initsJ` over the
-- `joinize`d list), folded with `parallelJoinListHom`.
initsP :: [a] -> Semiring a s -> s
initsP xs = initsJ parallelJoinListHom (joinize xs)
-- Parallel counterpart of `tails`: same generator (`tailsJ` over the
-- `joinize`d list), folded with `parallelJoinListHom`.
tailsP :: [a] -> Semiring a s -> s
tailsP xs = tailsJ parallelJoinListHom (joinize xs)
-- Parallel counterpart of `subs`: same generator (`subsJ` over the
-- `joinize`d list), folded with `parallelJoinListHom`.
subsP :: [a] -> Semiring a s -> s
subsP xs = subsJ parallelJoinListHom (joinize xs)
-- Parallel counterpart of `assigns`: same generator (`assignsJ` with marks
-- `ms` over the `joinize`d list), folded with `parallelJoinListHom`.
assignsP :: [m] -> [a] -> Semiring (m, a) s -> s
assignsP ms xs = assignsJ parallelJoinListHom ms (joinize xs)

-- Fold a join list with the given algebra, sparking the two halves of a
-- `Times` node in parallel for the top `sparkDepth` levels of the tree and
-- combining them directly below that depth.
parallelJoinListHom :: forall t a. JoinListAlgebra t a -> JoinList t -> a
parallelJoinListHom (JoinListAlgebra {..}) = go sparkDepth
    where
      -- 6 levels of forking, i.e. at most 2^6 = 64 parallel sub-computations
      sparkDepth = 6 :: Int
      go _ (Single a) = single a
      go _ Nil = nil
      go n (Times l r)
          -- spark the left half, evaluate the right half, then combine
          | n > 0     = lhs `par` (rhs `pseq` (lhs `times` rhs))
          | otherwise = lhs `times` rhs
          where lhs = go (n-1) l
                rhs = go (n-1) r