{-# LANGUAGE MultiParamTypeClasses,FlexibleInstances,FlexibleContexts,FunctionalDependencies,UndecidableInstances,RankNTypes,ExplicitForAll,ScopedTypeVariables,NoMonomorphismRestriction,OverlappingInstances,EmptyDataDecls,RecordWildCards,TypeFamilies,TemplateHaskell  #-}

module GTA.Data.JoinList (JoinList(Times, Single, Nil), JoinListAlgebra(JoinListAlgebra), times, single, nil, joinize, dejoinize, segs, inits, tails, subs, assigns, paths, assignsBy, mapJ, count, maxsum, maxsumsolution, maxsumWith, maxsumKWith, maxsumsolutionXKWith, maxsumsolutionXWith, maxsumsolutionWith, maxsumsolutionKWith, maxprodWith, maxprodKWith, maxprodsolutionXKWith, maxprodsolutionXWith, maxprodsolutionWith, maxprodsolutionKWith, segsP, initsP, tailsP, subsP, assignsP, assignsByP, crossConcat, bagOfSingleton, emptyBag, bagOfNil, bagUnion, Semiring) where


import GTA.Core
import GTA.Util.GenericSemiringStructureTemplate
import GTA.Data.BinTree (BinTree (..))
import Control.Parallel
import Control.DeepSeq
    
-- | Join list: a list represented as a binary tree.
--
-- * 'Times' l r — the concatenation of the lists @l@ and @r@;
-- * 'Single' x  — the one-element list @[x]@;
-- * 'Nil'       — the empty list.
--
-- The grouping of 'Times' nodes is not significant: the 'Show', 'Read',
-- 'Eq' and 'Ord' instances below all go through 'dejoinize', so trees that
-- hold the same elements in the same order behave identically.  For that
-- reason the derived instances are intentionally disabled.
data JoinList a = Times (JoinList a) (JoinList a)
                | Single a
                | Nil
--             deriving (Show, Eq, Ord, Read)

-- Template Haskell: generate the GTA framework declarations for 'JoinList'.
-- NOTE(review): judging from the names this module uses without defining
-- them, this presumably produces 'JoinListAlgebra', 'JoinListMapFs' and the
-- semiring-structure machinery behind 'hom' / 'freeSemiring' /
-- 'sumproductBy' / 'maxsumBy' — confirm against
-- GTA.Util.GenericSemiringStructureTemplate.
genAllDecl ''JoinList

-- | Reduce a join list to normal form: every node of the tree is visited
-- and every element is itself fully evaluated.
instance (NFData a) => (NFData (JoinList a)) where
  rnf jl = case jl of
    Times l r -> rnf l `seq` rnf r
    Single v  -> rnf v
    Nil       -> ()

-- | Convert an ordinary list into a balanced join list.
--
-- The list is split in half recursively (left half gets @n `div` 2@
-- elements), so the tree has depth about @log2 n@ and 'Times' nodes always
-- have two non-empty children.  The length is computed once up front and
-- passed down, instead of being recomputed for every sub-list.
joinize :: forall a. [a] -> JoinList a
joinize xs0 = go (length xs0) xs0
  where
    -- Invariant: the Int argument is always the length of the list argument.
    go :: Int -> [a] -> JoinList a
    go _ []  = Nil
    go _ [a] = Single a
    go n xs  = let half     = n `div` 2
                   (ls, rs) = splitAt half xs
               in Times (go half ls) (go (n - half) rs)

-- | Flatten a join list back into an ordinary list, left to right.
--
-- Uses an accumulator so each element is consed exactly once; the naive
-- @dejoinize l ++ dejoinize r@ is quadratic on left-deep trees.  The result
-- is still produced lazily, front first.
dejoinize :: forall a. JoinList a -> [a]
dejoinize t0 = go t0 []
  where
    go :: JoinList a -> [a] -> [a]
    go (Times l r) acc = go l (go r acc)
    go (Single a)  acc = a : acc
    go Nil         acc = acc

-- | A join list is shown exactly like the plain list it represents.
instance Show a => Show (JoinList a) where
    showsPrec prec = showsPrec prec . dejoinize

-- | Parsed from plain list syntax; each parsed list is rebuilt with 'joinize'.
instance Read a => Read (JoinList a) where
    readsPrec prec input = [ (joinize xs, rest) | (xs, rest) <- readsPrec prec input ]

-- | Equality of the flattened lists, regardless of tree shape.
instance Eq a => Eq (JoinList a) where
    l == r = dejoinize l == dejoinize r

-- | Lexicographic order of the flattened lists.
instance Ord a => Ord (JoinList a) where
    compare l r = dejoinize l `compare` dejoinize r



-- | Shorthand for a generic semiring over the join-list algebra with
-- elements of type @a@ and carrier @s@.
type Semiring a s = GenericSemiring (JoinListAlgebra a) s

-- Generators over ordinary lists.  Each one converts its input to a join
-- list with 'joinize' and delegates to the join-list generator of the same
-- name with a @J@ suffix.

-- | All segments (contiguous sub-lists), including the empty one.
segs :: [a] -> Semiring a s -> s
segs xs = segsJ (joinize xs)

-- | All prefixes, including the empty one.
inits :: [a] -> Semiring a s -> s
inits xs = initsJ (joinize xs)

-- | All suffixes, including the empty one.
tails :: [a] -> Semiring a s -> s
tails xs = tailsJ (joinize xs)

-- | All subsequences: each element is independently kept or dropped.
subs :: [a] -> Semiring a s -> s
subs xs = subsJ (joinize xs)

-- | Every way of pairing each element with one mark drawn from @ms@.
assigns :: [m] -> [a] -> Semiring (m, a) s -> s
assigns ms xs = assignsJ ms (joinize xs)

-- | Like 'assigns', but the candidate marks for element @x@ are @f x@.
assignsBy :: (a -> [m]) -> [a] -> Semiring (m, a) s -> s
assignsBy f xs = assignsByJ f (joinize xs)

-- | Segments generator on a join list.
--
-- A single homomorphic pass computes, for every sub-tree, the 4-tuple
-- (non-empty segments, non-empty prefixes, non-empty suffixes, whole list).
-- Segments of a concatenation are those of either half plus a suffix of the
-- left half joined to a prefix of the right half.  The empty segment is
-- added once at the very end.
segsJ :: JoinList a -> Semiring a s -> s
segsJ x (GenericSemiring {..}) = segsOf (hom alg x) `oplus` nil
  where
    segsOf ~(s, _, _, _) = s
    alg = JoinListAlgebra { times = joinT, single = leafT, nil = nilT }
    joinT ~(s1, i1, t1, a1) ~(s2, i2, t2, a2) =
      ( (s1 `oplus` s2) `oplus` (t1 `times` i2)
      , i1 `oplus` (a1 `times` i2)
      , (t1 `times` a2) `oplus` t2
      , a1 `times` a2 )
    leafT v = let sv = single v in (sv, sv, sv, sv)
    nilT = (identity, identity, identity, nil)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid
          
-- | Prefixes generator on a join list.
--
-- Each sub-tree yields (non-empty prefixes, whole list); prefixes of a
-- concatenation are the left half's prefixes plus the whole left half
-- joined to each right-half prefix.  The empty prefix is added at the end.
initsJ :: JoinList a -> Semiring a s -> s
initsJ x (GenericSemiring {..}) = nil `oplus` fst (hom alg x)
  where
    alg = JoinListAlgebra { times = joinT, single = leafT, nil = nilT }
    joinT ~(i1, a1) ~(i2, a2) = (i1 `oplus` (a1 `times` i2), a1 `times` a2)
    leafT v = let sv = single v in (sv, sv)
    nilT = (identity, nil)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Suffixes generator on a join list.
--
-- Each sub-tree yields (non-empty suffixes, whole list); suffixes of a
-- concatenation are each left-half suffix joined to the whole right half,
-- plus the right half's own suffixes.  The empty suffix is added at the end.
tailsJ :: JoinList a -> Semiring a s -> s
tailsJ x (GenericSemiring {..}) = fst (hom alg x) `oplus` nil
  where
    alg = JoinListAlgebra { times = joinT, single = leafT, nil = nilT }
    joinT ~(t1, a1) ~(t2, a2) = ((t1 `times` a2) `oplus` t2, a1 `times` a2)
    leafT v = let sv = single v in (sv, sv)
    nilT = (identity, nil)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Subsequences generator on a join list: every element contributes
-- either itself or the empty list; concatenation and 'Nil' are unchanged.
subsJ :: JoinList a -> Semiring a s -> s
subsJ x (GenericSemiring {..}) = hom subsAlg x
  where
    subsAlg = JoinListAlgebra { times = times, single = keepOrDrop, nil = nil }
    keepOrDrop v = single v `oplus` nil
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid
          
-- | Assignments generator on a join list: every element @v@ is replaced by
-- the choice of @(m, v)@ over all marks @m@ in @ms@ (the semiring zero when
-- @ms@ is empty).
assignsJ :: [m] -> JoinList a -> Semiring (m,a) s -> s
assignsJ ms x (GenericSemiring {..}) = hom assignAlg x
  where
    assignAlg = JoinListAlgebra { times = times, single = chooseMark, nil = nil }
    chooseMark v = foldr (\m acc -> single (m, v) `oplus` acc) identity ms
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Like 'assignsJ', but the candidate marks for element @v@ are @f v@.
assignsByJ :: (a -> [m]) -> JoinList a -> Semiring (m,a) s -> s
assignsByJ f x (GenericSemiring {..}) = hom assignAlg x
  where
    assignAlg = JoinListAlgebra { times = times, single = chooseMark, nil = nil }
    chooseMark v = foldr (\m acc -> single (m, v) `oplus` acc) identity (f v)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Root-to-leaf paths of a binary tree: for each leaf, the list of labels
-- met on the way from the root down to it.  (This generates lists from a
-- tree, whereas CYK generates trees from a list.)
paths :: BinTree a a -> Semiring a s -> s
paths x (GenericSemiring {..}) = go x
  where
    go t = case t of
      BinLeaf v     -> single v
      BinNode v l r -> single v `times` (go l `oplus` go r)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Build a join-list mapping record from a plain function; the function
-- is used as 'singleF', i.e. applied to each element.
mapJ :: forall b a. (b -> a) -> JoinListMapFs b a
mapJ f = JoinListMapFs { singleF = f }

-- | Counting semiring: every element is mapped to 1 before being combined
-- with 'sumproductBy'.
count :: Num a => Semiring b a
count = sumproductBy (mapJ (const 1))


{- simplified aggregators -}

-- | Max-sum aggregator that uses each element itself as its weight
-- (wrapped with 'addIdentity').
maxsum :: (Ord a, Num a) => Semiring a (AddIdentity a)
maxsum = maxsumBy (mapJ addIdentity)

-- | Like 'maxsum', but the value is paired with a 'Bag' of join lists, as
-- computed by 'maxsumsolutionBy'.
maxsumsolution :: (Ord a, Num a) => Semiring a (AddIdentity a, Bag (JoinList a))
maxsumsolution = maxsumsolutionBy (mapJ addIdentity)

-- Max-sum aggregators whose element weights come from a user function @f@.
-- Each one wraps @f@ with 'addIdentity' and delegates to the matching
-- @maxsum…By@ combinator.  Going by their types: the @K@ variants take a
-- count @k@ and return a list; the @X@ variants pair each value with the
-- result of an extra semiring @s@; the @solution@ variants pair it with a
-- 'Bag' of join lists.

maxsumWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a)
maxsumWith f = maxsumBy (mapJ (\v -> addIdentity (f v)))

maxsumKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b ([AddIdentity a])
maxsumKWith k f = maxsumKBy k (mapJ (\v -> addIdentity (f v)))

maxsumsolutionXKWith :: (Ord a, Num a) =>
                       Semiring c b -> Int -> (c -> a) -> Semiring c [(AddIdentity a, b)]
maxsumsolutionXKWith s k f = maxsumsolutionXKBy s k (mapJ (\v -> addIdentity (f v)))

maxsumsolutionXWith :: (Ord a, Num a) =>
                       Semiring c b -> (c -> a) -> Semiring c (AddIdentity a, b)
maxsumsolutionXWith s f = maxsumsolutionXBy s (mapJ (\v -> addIdentity (f v)))

maxsumsolutionWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a, Bag (JoinList b))
maxsumsolutionWith f = maxsumsolutionBy (mapJ (\v -> addIdentity (f v)))

maxsumsolutionKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b [(AddIdentity a, Bag (JoinList b))]
maxsumsolutionKWith k f = maxsumsolutionKBy k (mapJ (\v -> addIdentity (f v)))

-- Max-product aggregators whose element weights come from a user function
-- @f@.  Each one wraps @f@ with 'addIdentity' and delegates to the matching
-- @maxprod…By@ combinator.  The @K@ / @X@ / @solution@ suffixes follow the
-- same pattern as the max-sum family.

maxprodWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a)
maxprodWith f = maxprodBy (mapJ (\v -> addIdentity (f v)))

maxprodKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b ([AddIdentity a])
maxprodKWith k f = maxprodKBy k (mapJ (\v -> addIdentity (f v)))

maxprodsolutionXKWith :: (Ord a, Num a) =>
                       Semiring c b -> Int -> (c -> a) -> Semiring c [(AddIdentity a, b)]
maxprodsolutionXKWith s k f = maxprodsolutionXKBy s k (mapJ (\v -> addIdentity (f v)))

maxprodsolutionXWith :: (Ord a, Num a) =>
                       Semiring c b -> (c -> a) -> Semiring c (AddIdentity a, b)
maxprodsolutionXWith s f = maxprodsolutionXBy s (mapJ (\v -> addIdentity (f v)))

maxprodsolutionWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a, Bag (JoinList b))
maxprodsolutionWith f = maxprodsolutionBy (mapJ (\v -> addIdentity (f v)))

maxprodsolutionKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b [(AddIdentity a, Bag (JoinList b))]
maxprodsolutionKWith k f = maxprodsolutionKBy k (mapJ (\v -> addIdentity (f v)))


--- parallel generators

-- | Parallel 'segs': converts the list with 'joinize' and runs 'segsJP'.
segsP :: (NFData s) => [a] -> Semiring a s -> s
segsP xs = segsJP (joinize xs)

-- | Parallel segments generator on a join list.  Same 4-tuple algebra as
-- the sequential version — (non-empty segments, non-empty prefixes,
-- non-empty suffixes, whole list) per sub-tree — but evaluated with
-- 'parallelJoinListHom'.  The empty segment is added at the end.
segsJP :: (NFData s) => JoinList a -> Semiring a s -> s
segsJP x (GenericSemiring {..}) = segsOf (parallelJoinListHom alg x) `oplus` nil
  where
    segsOf ~(s, _, _, _) = s
    alg = JoinListAlgebra { times = joinT, single = leafT, nil = nilT }
    joinT ~(s1, i1, t1, a1) ~(s2, i2, t2, a2) =
      ( (s1 `oplus` s2) `oplus` (t1 `times` i2)
      , i1 `oplus` (a1 `times` i2)
      , (t1 `times` a2) `oplus` t2
      , a1 `times` a2 )
    leafT v = let sv = single v in (sv, sv, sv, sv)
    nilT = (identity, identity, identity, nil)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid
          

-- | Parallel 'inits': converts the list with 'joinize' and runs 'initsJP'.
initsP :: (NFData s) => [a] -> Semiring a s -> s
initsP xs = initsJP (joinize xs)

-- | Parallel prefixes generator on a join list: the (non-empty prefixes,
-- whole list) algebra evaluated with 'parallelJoinListHom'; the empty
-- prefix is added at the end.
initsJP :: (NFData s) => JoinList a -> Semiring a s -> s
initsJP x (GenericSemiring {..}) = nil `oplus` fst (parallelJoinListHom alg x)
  where
    alg = JoinListAlgebra { times = joinT, single = leafT, nil = nilT }
    joinT ~(i1, a1) ~(i2, a2) = (i1 `oplus` (a1 `times` i2), a1 `times` a2)
    leafT v = let sv = single v in (sv, sv)
    nilT = (identity, nil)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Parallel 'tails': converts the list with 'joinize' and runs 'tailsJP'.
tailsP :: (NFData s) => [a] -> Semiring a s -> s
tailsP xs = tailsJP (joinize xs)

-- | Parallel suffixes generator on a join list: the (non-empty suffixes,
-- whole list) algebra evaluated with 'parallelJoinListHom'; the empty
-- suffix is added at the end.
tailsJP :: (NFData s) => JoinList a -> Semiring a s -> s
tailsJP x (GenericSemiring {..}) = fst (parallelJoinListHom alg x) `oplus` nil
  where
    alg = JoinListAlgebra { times = joinT, single = leafT, nil = nilT }
    joinT ~(t1, a1) ~(t2, a2) = ((t1 `times` a2) `oplus` t2, a1 `times` a2)
    leafT v = let sv = single v in (sv, sv)
    nilT = (identity, nil)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Parallel 'subs': converts the list with 'joinize' and runs 'subsJP'.
subsP :: (NFData s) => [a] -> Semiring a s -> s
subsP xs = subsJP (joinize xs)

-- | Parallel subsequences generator: each element contributes either itself
-- or the empty list, evaluated with 'parallelJoinListHom'.
subsJP :: (NFData s) => JoinList a -> Semiring a s -> s
subsJP x (GenericSemiring {..}) = parallelJoinListHom subsAlg x
  where
    subsAlg = JoinListAlgebra { times = times, single = keepOrDrop, nil = nil }
    keepOrDrop v = single v `oplus` nil
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid
          
-- | Parallel 'assigns': converts the list with 'joinize' and runs 'assignsJP'.
assignsP :: (NFData s) => [m] -> [a] -> Semiring (m, a) s -> s
assignsP ms xs = assignsJP ms (joinize xs)

-- | Parallel assignments generator: every element @v@ becomes the choice of
-- @(m, v)@ over all marks in @ms@, evaluated with 'parallelJoinListHom'.
assignsJP :: (NFData s) => [m] -> JoinList a -> Semiring (m,a) s -> s
assignsJP  ms x (GenericSemiring {..}) = parallelJoinListHom assignAlg x
  where
    assignAlg = JoinListAlgebra { times = times, single = chooseMark, nil = nil }
    chooseMark v = foldr (\m acc -> single (m, v) `oplus` acc) identity ms
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid

-- | Parallel 'assignsBy': converts the list with 'joinize' and runs
-- 'assignsByJP'.
assignsByP :: (NFData s) => (a -> [m]) -> [a] -> Semiring (m, a) s -> s
assignsByP f xs = assignsByJP f (joinize xs)

-- | Parallel 'assignsByJ': the candidate marks for element @v@ are @f v@.
assignsByJP :: (NFData s) => (a -> [m]) -> JoinList a -> Semiring (m,a) s -> s
assignsByJP f x (GenericSemiring {..}) = parallelJoinListHom assignAlg x
  where
    assignAlg = JoinListAlgebra { times = times, single = chooseMark, nil = nil }
    chooseMark v = foldr (\m acc -> single (m, v) `oplus` acc) identity (f v)
    JoinListAlgebra {..} = algebra
    CommutativeMonoid {..} = monoid



-- | Evaluate a join-list homomorphism, computing the two children of each
-- 'Times' node in parallel for the top 'sparkDepth' levels of the tree;
-- deeper levels are evaluated sequentially.
--
-- Sub-results are forced to normal form with 'rnf'.  The previous
-- @p1 `par` (p2 `pseq` …)@ only reached WHNF.  For the tuple carriers used
-- by the segs/inits/tails generators, WHNF is just the tuple constructor, so
-- the sparks did no real work and the declared 'NFData' constraint went
-- unused.  Note that full forcing also evaluates tuple components that a
-- caller might later discard.
parallelJoinListHom :: forall t a. (NFData a) => JoinListAlgebra t a -> JoinList t -> a
parallelJoinListHom (JoinListAlgebra {..}) = h sparkDepth
  where
    -- 6 levels of splitting: at most 2^6 = 64 sub-trees evaluated in parallel.
    sparkDepth :: Int
    sparkDepth = 6

    h :: Int -> JoinList t -> a
    h n (x1 `Times` x2)
      | n > 0     = rnf p1 `par` (rnf p2 `pseq` (p1 `times` p2))
      | otherwise = p1 `times` p2
      where
        p1 = h (n - 1) x1
        p2 = h (n - 1) x2
    h _ (Single v) = single v
    h _ Nil        = nil

--- useful functions to design generators: constructors of bags of lists
-- | Cross concatenation of two bags of lists: the 'times' of the free
-- join-list semiring.
crossConcat :: Bag (JoinList a) -> Bag (JoinList a) -> Bag (JoinList a)
crossConcat xs ys = times (algebra freeSemiring) xs ys

-- | The bag holding the one-element list of @v@: the 'single' of the free
-- join-list semiring.
bagOfSingleton :: a -> Bag (JoinList a)
bagOfSingleton v = single (algebra freeSemiring) v

-- | The bag holding just the empty list: the 'nil' of the free join-list
-- semiring.
bagOfNil :: Bag (JoinList a)
bagOfNil = nil alg
  where alg = algebra freeSemiring

-- | The empty bag: the additive identity of the free join-list semiring.
-- The annotation pins the algebra, which the result type alone leaves
-- ambiguous.
emptyBag :: Bag (JoinList a)
emptyBag = identity (monoid free)
  where
    free = freeSemiring :: GenericSemiring (JoinListAlgebra a) (Bag (JoinList a))

-- | Union of two bags of lists: the 'oplus' of the free join-list semiring.
-- The annotation pins the algebra, which the argument types alone leave
-- ambiguous.
bagUnion :: Bag (JoinList a) -> Bag (JoinList a) -> Bag (JoinList a)
bagUnion xs ys = oplus (monoid free) xs ys
  where
    free = freeSemiring :: GenericSemiring (JoinListAlgebra a) (Bag (JoinList a))