{-# LANGUAGE MultiParamTypeClasses,FlexibleInstances,FlexibleContexts,FunctionalDependencies,UndecidableInstances,RankNTypes,ExplicitForAll,ScopedTypeVariables,NoMonomorphismRestriction,OverlappingInstances,EmptyDataDecls,RecordWildCards,TypeFamilies,TemplateHaskell  #-}

{-| This module provides the GTA framework for join lists: the definitions of the data structure and its algebra, serial and parallel generators, aggregators, and related utilities.
-}
module GTA.Data.JoinList (JoinList(Times, Single, Nil), JoinListAlgebra(JoinListAlgebra, times, single, nil), joinize, dejoinize, segs, inits, tails, subs, assigns, paths, assignsBy, mapJ, count, maxsum, maxsumsolution, maxsumWith, maxsumKWith, maxsumsolutionXKWith, maxsumsolutionXWith, maxsumsolutionWith, maxsumsolutionKWith, maxprodWith, maxprodKWith, maxprodsolutionXKWith, maxprodsolutionXWith, maxprodsolutionWith, maxprodsolutionKWith, segsP, initsP, tailsP, subsP, assignsP, assignsByP, crossConcat, bagOfSingleton, emptyBag, bagOfNil, bagUnion, Semiring, prop_Associativity, prop_Identity,joinListAlgebra,JoinListMapFs(singleF)) where


import GTA.Core
import GTA.Util.GenericSemiringStructureTemplate
import GTA.Data.BinTree (BinTree (..))
import Control.Parallel
import Control.DeepSeq
    
-- join list = associative binary tree
{-|
Join lists: lists represented as binary trees whose internal nodes are
concatenations.

> x ++ y ==> x `Times` y
> [a]    ==> Single a
> []     ==> Nil

We assume that `Times` is associative and `Nil` is its identity:

> x `Times` (y `Times` z) == (x `Times` y) `Times` z
> x `Times` Nil == Nil `Times` x == x

The constructors do not enforce these laws. Instead, the 'Eq', 'Ord', 'Show'
and 'Read' instances in this module go through 'dejoinize' / 'joinize'. As a
result, trees that differ only in bracketing or in where 'Nil' appears are
treated as the same list.
 -}
data JoinList a = Times (JoinList a) (JoinList a)
                | Single a
                | Nil
-- NOTE: intentionally no deriving clause; Show/Eq/Ord/Read are written by
-- hand below in terms of the plain list obtained with 'dejoinize'.

-- to use the GTA framework
-- The following definitions can be generated automatically by @genAllDecl ''JoinList@
-- They are written by hand here so that they can carry documentation comments.

-- algebra of JoinList
{-|
The algebra of join lists: one operation per constructor of 'JoinList'.

* @times@ interprets 'Times' and combines the results for two sub-lists.
* @single@ interprets 'Single' and maps an element of type @b@ into the carrier @a@.
* @nil@ interprets 'Nil'.

We assume that `times` is associative and `nil` is its identity, inheriting those of `Times` and `Nil`:

> x `times` (y `times` z) == (x `times` y) `times` z
> x `times` nil == nil `times` x == x

The types do not check these laws. 'prop_Associativity' and 'prop_Identity'
state them as testable predicates.

This can be generated automatically by @genAllDecl ''JoinList@.
-}
data JoinListAlgebra b a = JoinListAlgebra {
      times  :: a -> a -> a,
      single :: b -> a,
      nil    :: a
    }

-- a set of functions for 'map'
{-|
A record holding the function applied to each element of a join list.
'singleF' maps an element of type @b@ to a value of type @b'@. For example,
'foldingAlgebra' uses it to interpret 'Single'.

'mapJ' is a convenience constructor for this record.

This can be generated automatically by @genAllDecl ''JoinList@.
-}
data JoinListMapFs b b' = JoinListMapFs {
      singleF :: b -> b'
    }

-- type parameters are algebra, free algebra, and functions for 'map'
{-|
Instance declaration of GTA.Data.GenericSemiringStructure for join lists. The implementation is straightforward.

The three class parameters are:

* the algebra type ('JoinListAlgebra' @b@),
* the free structure ('JoinList' @b@),
* the record of element functions ('JoinListMapFs' @b@).

This can be generated automatically by @genAllDecl ''JoinList@.
-}
instance GenericSemiringStructure (JoinListAlgebra b) (JoinList b) (JoinListMapFs b) where
  -- The free algebra: every operation is the corresponding constructor.
  freeAlgebra = JoinListAlgebra {..} where
      times  = Times
      single = Single
      nil    = Nil
  -- Runs two algebras side by side on pairs, component-wise.
  pairAlgebra jla1 jla2 = JoinListAlgebra {..} where
      times (l1, l2) (r1, r2) = (times1 l1 r1, times2 l2 r2)
      single a                = (single1 a, single2 a)
      nil                     = (nil1, nil2)
      -- Each inner `let` re-binds the fields of one argument algebra,
      -- shadowing the local `times`/`single`/`nil` defined just above.
      (times1, single1, nil1) = let JoinListAlgebra {..} = jla1 in (times, single, nil)
      (times2, single2, nil2) = let JoinListAlgebra {..} = jla2 in (times, single, nil)
  -- Lifts `jla` onto the monoid carrier.
  -- `frec` turns a carrier value into a list of candidates.
  -- `jla`'s `times` combines every left/right pair of candidates.
  -- `fsingle` injects each combined result, and the results are summed with
  -- `oplus`, starting from `identity`.
  makeAlgebra (CommutativeMonoid {..}) jla frec fsingle = JoinListAlgebra {..} where  
      times l r = foldr oplus identity [fsingle (times' l' r') | l' <- frec l, r' <- frec r]
      single a  = fsingle (single' a)
      nil       = fsingle nil'
      (times', single', nil') = let JoinListAlgebra {..} = jla in (times, single, nil)
  -- Builds an algebra from a binary operator `op` (for `Times`), a unit
  -- `iop` (for `Nil`) and the element function `singleF` (for `Single`).
  foldingAlgebra op iop (JoinListMapFs {..}) = JoinListAlgebra {..} where
      times l r = l `op` r
      single a  = singleF a
      nil       = iop
  -- The catamorphism: replaces every constructor with the algebra's operation.
  hom (JoinListAlgebra {..}) = h where
      h (Times l r) = times (h l) (h r)
      h (Single a)  = single a
      h Nil         = nil

{-| Builds a 'JoinListAlgebra' from its three operations, given positionally
in the order @times@, @single@, @nil@.

It exists as a workaround for cabal's brace-eating bug, so that callers need
not write record syntax.
-}
joinListAlgebra :: (a -> a -> a) -> (b -> a) -> a -> JoinListAlgebra b a
joinListAlgebra combine inject unit =
    JoinListAlgebra { times = combine, single = inject, nil = unit }

-- properties of JoinListAlgebra for correct parallelization
{-| Associativity of `times` of a JoinListAlgebra, checked on one triple:

 > x `times` (y `times` z) == (x `times` y) `times` z

If this law holds, the result of a homomorphism does not depend on how a list
is bracketed into a join list.
 -}
prop_Associativity :: (Eq b) => JoinListAlgebra a b -> (b,b,b) -> Bool 
prop_Associativity alg (x, y, z) = rightNested == leftNested
  where
    op          = times alg
    rightNested = x `op` (y `op` z)
    leftNested  = (x `op` y) `op` z

{-| Identity law of `nil` with respect to `times` of a JoinListAlgebra,
checked on one value:

 > (x `times` nil == x) && (nil `times` x == x)

 -}
prop_Identity :: (Eq b) => JoinListAlgebra a b -> b -> Bool 
prop_Identity alg x = rightUnit && leftUnit
  where
    unit      = nil alg
    rightUnit = times alg x unit == x
    leftUnit  = times alg unit x == x

-- | Full evaluation of a join list: every node and every element is forced.
instance (NFData a) => NFData (JoinList a) where
  rnf jl = case jl of
    Times l r -> rnf l `seq` rnf r
    Single x  -> rnf x
    Nil       -> ()

{-| Conversion from a usual list to a join list.

The result is balanced:

* the empty list becomes 'Nil';
* a singleton becomes 'Single';
* a list of length @n >= 2@ becomes @'Times' l r@, where @l@ is built from
  the first @div n 2@ elements and @r@ from the remaining ones.

The length is computed once, and the list is consumed left to right in a
single pass, so the conversion takes /O(n)/ time. The input must be finite.
-}
joinize :: forall a. [a] -> JoinList a
joinize xs = fst (build (length xs) xs)
  where
    -- @build n ys@ turns the first @n@ elements of @ys@ into a balanced join
    -- list and also returns the elements it did not consume.
    build _ []       = (Nil, [])
    build 0 ys       = (Nil, ys)
    build 1 (y : ys) = (Single y, ys)
    build n ys       =
      let d         = n `div` 2
          (l, ys')  = build d ys
          (r, ys'') = build (n - d) ys'
      in (Times l r, ys'')

{-| Conversion from a join list to a usual list.

Elements appear in the left-to-right order of the tree, and 'Nil'
contributes nothing. The traversal uses an accumulator instead of nested
@(++)@, so it takes time linear in the size of the tree even for left-deep
trees.
-}
dejoinize :: forall a. JoinList a -> [a]
dejoinize t = go t []
  where
    -- @go u rest@ prepends the elements of @u@ to @rest@.
    go (Times l r) rest = go l (go r rest)
    go (Single a)  rest = a : rest
    go Nil         rest = rest

-- | Shows a join list exactly like the plain list it represents; the tree
-- structure is not visible.
instance Show a => Show (JoinList a) where
    showsPrec prec = showsPrec prec . dejoinize

-- | Reads plain-list syntax and converts each parse with 'joinize'.
instance Read a => Read (JoinList a) where
    readsPrec prec input = [ (joinize xs, rest) | (xs, rest) <- readsPrec prec input ]

-- | Two join lists are equal when they represent the same plain list,
-- whatever the shape of their trees.
instance Eq a => Eq (JoinList a) where
    xs == ys = dejoinize xs == dejoinize ys

-- | Join lists are ordered lexicographically, like the plain lists they represent.
instance Ord a => Ord (JoinList a) where
    compare xs ys = dejoinize xs `compare` dejoinize ys



-- renaming
{-| The usual semiring is a generic semiring of join lists. It is a
'GenericSemiring' whose algebra is a 'JoinListAlgebra' over elements of type
@a@ and whose carrier is @s@. Its @times@ distributes over @oplus@, and the
monoid's @identity@ is a zero of @times@:

> a `times` (b `oplus` c) == (a `times` b) `oplus` (a `times` c)
> (a `oplus` b) `times` c == (a `times` c) `oplus` (b `times` c)
> a `times` identity == identity `times` a == identity

 -}
type Semiring a s = GenericSemiring (JoinListAlgebra a) s

{-| This generates all segments (contiguous sublists) of a given list, including the empty one.

For example, 

>>> segs [1,2,3] `aggregateBy` result
Bag [[1],[2],[3],[2,3],[1,2],[1,2,3],[]]

-}
segs :: [a] -> Semiring a s -> s
segs xs = segsJ (joinize xs)

{-| This generates all prefixes of a given list, including the empty list and the whole list.

For example, 

>>> inits [1,2,3] `aggregateBy` result
Bag [[],[1],[1,2],[1,2,3]]

-}
inits :: [a] -> Semiring a s -> s
inits xs = initsJ (joinize xs)

{-| This generates all suffixes of a given list, including the whole list and the empty list.

For example, 

>>> tails [1,2,3] `aggregateBy` result
Bag [[1,2,3],[2,3],[3],[]]

-}
tails :: [a] -> Semiring a s -> s
tails xs = tailsJ (joinize xs)

{-| This generates all subsequences (not necessarily contiguous) of a given list, including the empty one.

For example, 

>>> subs [1,2,3] `aggregateBy` result
Bag [[1,2,3],[1,2],[1,3],[1],[2,3],[2],[3],[]]

-}
subs :: [a] -> Semiring a s -> s
subs xs = subsJ (joinize xs)

{-| This generates all assignments of elements of the first list (the marks)
to elements of the second list. Every element of the second list is paired
with each possible mark.

For example, 

>>> assigns [True,False] [1,2,3] `aggregateBy` result
Bag [[(True,1),(True,2),(True,3)],[(True,1),(True,2),(False,3)],[(True,1),(False,2),(True,3)],[(True,1),(False,2),(False,3)],[(False,1),(True,2),(True,3)],[(False,1),(True,2),(False,3)],[(False,1),(False,2),(True,3)],[(False,1),(False,2),(False,3)]]

-}
assigns :: [m] -> [a] -> Semiring (m, a) s -> s
assigns ms xs = assignsJ ms (joinize xs)

{-| This is a generalization of `assigns` in which the marks that can be
assigned depend on the target element: @f a@ lists the marks allowed for @a@.

For example, 

>>> assignsBy (\a -> if odd a then [True, False] else [True]) [1,2,3] `aggregateBy` result
Bag [[(True,1),(True,2),(True,3)],[(True,1),(True,2),(False,3)],[(False,1),(True,2),(True,3)],[(False,1),(True,2),(False,3)]]

-}
assignsBy :: (a -> [m]) -> [a] -> Semiring (m, a) s -> s
assignsBy f xs = assignsByJ f (joinize xs)

-- | Serial homomorphism behind 'segs'. Each sub-list is summarised by a
-- quadruple (segments, prefixes, suffixes, whole list). The first three
-- components aggregate only non-empty candidates; the empty segment is
-- added once at the end with @nil@.
segsJ :: JoinList a -> Semiring a s -> s
segsJ xs (GenericSemiring {..}) = allSegs `oplus` nil
    where (allSegs, _, _, _) = hom (joinListAlgebra joinQuad leafQuad emptyQuad) xs
          -- A segment of @l ++ r@ lies in @l@, lies in @r@, or is a suffix of
          -- @l@ followed by a prefix of @r@. Lazy patterns keep the original
          -- non-strict behaviour.
          joinQuad ~(s1, i1, t1, a1) ~(s2, i2, t2, a2) =
              ( (s1 `oplus` s2) `oplus` (t1 `times` i2)
              , i1 `oplus` (a1 `times` i2)
              , (t1 `times` a2) `oplus` t2
              , a1 `times` a2 )
          leafQuad x = let sx = single x in (sx, sx, sx, sx)
          emptyQuad = (identity, identity, identity, nil)
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid
          
-- | Serial homomorphism behind 'inits'. Each sub-list is summarised by a pair
-- (non-empty prefixes, whole list); the empty prefix is added at the end
-- with @nil@.
initsJ :: JoinList a -> Semiring a s -> s
initsJ xs (GenericSemiring {..}) = nil `oplus` prefixes
    where (prefixes, _) = hom (joinListAlgebra joinPair leafPair emptyPair) xs
          -- A non-empty prefix of @l ++ r@ is a prefix of @l@, or all of @l@
          -- followed by a non-empty prefix of @r@.
          joinPair ~(i1, a1) ~(i2, a2) = (i1 `oplus` (a1 `times` i2), a1 `times` a2)
          leafPair x = let sx = single x in (sx, sx)
          emptyPair = (identity, nil)
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

-- | Serial homomorphism behind 'tails'. Each sub-list is summarised by a pair
-- (non-empty suffixes, whole list); the empty suffix is added at the end
-- with @nil@.
tailsJ :: JoinList a -> Semiring a s -> s
tailsJ xs (GenericSemiring {..}) = suffixes `oplus` nil
    where (suffixes, _) = hom (joinListAlgebra joinPair leafPair emptyPair) xs
          -- A non-empty suffix of @l ++ r@ is a non-empty suffix of @l@
          -- followed by all of @r@, or a non-empty suffix of @r@.
          joinPair ~(t1, a1) ~(t2, a2) = ((t1 `times` a2) `oplus` t2, a1 `times` a2)
          leafPair x = let sx = single x in (sx, sx)
          emptyPair = (identity, nil)
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

-- | Serial homomorphism behind 'subs'. Each element is either kept
-- (@single x@) or dropped (@nil@); @times@ and @nil@ are used unchanged.
subsJ :: JoinList a -> Semiring a s -> s
subsJ xs (GenericSemiring {..}) = hom (joinListAlgebra times keepOrDrop nil) xs
    where keepOrDrop x = single x `oplus` nil
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid
          
-- | Serial homomorphism behind 'assigns'. Each element @x@ becomes the sum of
-- @single (m, x)@ over all marks @m@ in @ms@.
assignsJ :: [m] -> JoinList a -> Semiring (m,a) s -> s
assignsJ ms xs (GenericSemiring {..}) = hom (joinListAlgebra times choices nil) xs
    where choices x = foldr oplus identity [ single (m, x) | m <- ms ]
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

-- | Serial homomorphism behind 'assignsBy'. Each element @x@ becomes the sum
-- of @single (m, x)@ over the marks @m@ in @f x@.
assignsByJ :: (a -> [m]) -> JoinList a -> Semiring (m,a) s -> s
assignsByJ f xs (GenericSemiring {..}) = hom (joinListAlgebra times choices nil) xs
    where choices x = foldr oplus identity [ single (m, x) | m <- f x ]
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

{- This generates lists from a tree, while CYK generates trees from a list. -}
{-| This generates all paths from the root to the leaves of a given binary
tree. Each path is the list of labels of the nodes along it.

For example, 

>>> paths (BinNode 1 (BinLeaf 2) (BinNode 3 (BinLeaf 4) (BinLeaf 5))) `aggregateBy` result
Bag [[1,2],[1,3,4],[1,3,5]]

-}
paths :: BinTree a a -> Semiring a s -> s
paths tree (GenericSemiring {..}) = walk tree
    where walk (BinLeaf x)     = single x
          walk (BinNode x l r) = single x `times` (walk l `oplus` walk r)
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

-- useful function to map
{-| Wrapper for 'JoinListMapFs': applies the given function to every element.
-}
mapJ :: forall b a. (b -> a) -> JoinListMapFs b a
mapJ f = JoinListMapFs { singleF = f }

-- JoinList-semiring for counting
{-| The aggregator to count the number of items in a generated bag. Every
element is weighted by @1@ before the sum-product aggregation.
-}
count :: Num a => Semiring b a
count = sumproductBy (mapJ (const 1))


{- simplified aggregators -}
{-| The aggregator to take the maximum sum of the elements of the generated items.
-}
maxsum :: (Ord a, Num a) => Semiring a (AddIdentity a)
maxsum = maxsumBy (mapJ addIdentity)

{-| The aggregator to find items with the maximum sum, together with that sum.
-}
maxsumsolution :: (Ord a, Num a) => Semiring a (AddIdentity a, Bag (JoinList a))
maxsumsolution = maxsumsolutionBy (mapJ addIdentity)

{-| The aggregator to take the maximum sum after @map f@.
-}
maxsumWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a)
maxsumWith f = maxsumBy weights
  where weights = mapJ (addIdentity . f)

{-| The /best-k/ extension of `maxsumWith`: keeps the @k@ largest sums.
-}
maxsumKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b ([AddIdentity a])
maxsumKWith k f = maxsumKBy k weights
  where weights = mapJ (addIdentity . f)

{-| The /best-k/ extension of `maxsumsolutionXWith`.
-}
maxsumsolutionXKWith :: (Ord a, Num a) =>
                       Semiring c b -> Int -> (c -> a) -> Semiring c [(AddIdentity a, b)]
maxsumsolutionXKWith s k f = maxsumsolutionXKBy s k weights
  where weights = mapJ (addIdentity . f)

{-| The tupling of maxsumsolution and a given semiring. The second component
is the aggregation, by the given semiring, of the items with the maximum sum.
-}
maxsumsolutionXWith :: (Ord a, Num a) =>
                       Semiring c b -> (c -> a) -> Semiring c (AddIdentity a, b)
maxsumsolutionXWith s f = maxsumsolutionXBy s weights
  where weights = mapJ (addIdentity . f)

{-| The aggregator to find items with the maximum sum after @map f@.
-}
maxsumsolutionWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a, Bag (JoinList b))
maxsumsolutionWith f = maxsumsolutionBy weights
  where weights = mapJ (addIdentity . f)

{-| The /best-k/ extension of `maxsumsolutionWith`.
-}
maxsumsolutionKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b [(AddIdentity a, Bag (JoinList b))]
maxsumsolutionKWith k f = maxsumsolutionKBy k weights
  where weights = mapJ (addIdentity . f)

{-| The aggregator to take the maximum product of /non-negative/ numbers after @map f@.
-}
maxprodWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a)
maxprodWith f = maxprodBy weights
  where weights = mapJ (addIdentity . f)

{-| The /best-k/ extension of `maxprodWith`: keeps the @k@ largest products.
-}
maxprodKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b ([AddIdentity a])
maxprodKWith k f = maxprodKBy k weights
  where weights = mapJ (addIdentity . f)

{-| The /best-k/ extension of `maxprodsolutionXWith`.
-}
maxprodsolutionXKWith :: (Ord a, Num a) =>
                       Semiring c b -> Int -> (c -> a) -> Semiring c [(AddIdentity a, b)]
maxprodsolutionXKWith s k f = maxprodsolutionXKBy s k weights
  where weights = mapJ (addIdentity . f)

{-| The tupling of maxprodsolution and a given semiring. The second component
is the aggregation, by the given semiring, of the items with the maximum product.
-}
maxprodsolutionXWith :: (Ord a, Num a) =>
                       Semiring c b -> (c -> a) -> Semiring c (AddIdentity a, b)
maxprodsolutionXWith s f = maxprodsolutionXBy s weights
  where weights = mapJ (addIdentity . f)

{-| The aggregator to find items with the maximum product after @map f@. The numbers have to be /non-negative/.
-}
maxprodsolutionWith :: (Ord a, Num a) => (b -> a) -> Semiring b (AddIdentity a, Bag (JoinList b))
maxprodsolutionWith f = maxprodsolutionBy weights
  where weights = mapJ (addIdentity . f)

{-| The /best-k/ extension of `maxprodsolutionWith`.
-}
maxprodsolutionKWith :: (Ord a, Num a) => Int -> (b -> a) -> Semiring b [(AddIdentity a, Bag (JoinList b))]
maxprodsolutionKWith k f = maxprodsolutionKBy k weights
  where weights = mapJ (addIdentity . f)


--- parallel generators

{-| Parallel version of `segs`: same result, computed with parallel evaluation. -}
segsP :: (NFData s) => [a] -> Semiring a s -> s
segsP xs = segsJP (joinize xs)

-- | Parallel counterpart of 'segsJ': the same quadruple algebra
-- (segments, prefixes, suffixes, whole list), evaluated with
-- 'parallelJoinListHom'.
segsJP :: (NFData s) => JoinList a -> Semiring a s -> s
segsJP xs (GenericSemiring {..}) = allSegs `oplus` nil
    where (allSegs, _, _, _) = parallelJoinListHom (joinListAlgebra joinQuad leafQuad emptyQuad) xs
          joinQuad ~(s1, i1, t1, a1) ~(s2, i2, t2, a2) =
              ( (s1 `oplus` s2) `oplus` (t1 `times` i2)
              , i1 `oplus` (a1 `times` i2)
              , (t1 `times` a2) `oplus` t2
              , a1 `times` a2 )
          leafQuad x = let sx = single x in (sx, sx, sx, sx)
          emptyQuad = (identity, identity, identity, nil)
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid
          

{-| Parallel version of `inits`: same result, computed with parallel evaluation. -}
initsP :: (NFData s) => [a] -> Semiring a s -> s
initsP xs = initsJP (joinize xs)

-- | Parallel counterpart of 'initsJ': the same (prefixes, whole list) pair
-- algebra, evaluated with 'parallelJoinListHom'.
initsJP :: (NFData s) => JoinList a -> Semiring a s -> s
initsJP xs (GenericSemiring {..}) = nil `oplus` prefixes
    where (prefixes, _) = parallelJoinListHom (joinListAlgebra joinPair leafPair emptyPair) xs
          joinPair ~(i1, a1) ~(i2, a2) = (i1 `oplus` (a1 `times` i2), a1 `times` a2)
          leafPair x = let sx = single x in (sx, sx)
          emptyPair = (identity, nil)
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

{-| Parallel version of `tails`: same result, computed with parallel evaluation. -}
tailsP :: (NFData s) => [a] -> Semiring a s -> s
tailsP xs = tailsJP (joinize xs)

-- | Parallel counterpart of 'tailsJ': the same (suffixes, whole list) pair
-- algebra, evaluated with 'parallelJoinListHom'.
tailsJP :: (NFData s) => JoinList a -> Semiring a s -> s
tailsJP xs (GenericSemiring {..}) = suffixes `oplus` nil
    where (suffixes, _) = parallelJoinListHom (joinListAlgebra joinPair leafPair emptyPair) xs
          joinPair ~(t1, a1) ~(t2, a2) = ((t1 `times` a2) `oplus` t2, a1 `times` a2)
          leafPair x = let sx = single x in (sx, sx)
          emptyPair = (identity, nil)
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

{-| Parallel version of `subs`: same result, computed with parallel evaluation. -}
subsP :: (NFData s) => [a] -> Semiring a s -> s
subsP xs = subsJP (joinize xs)

-- | Parallel counterpart of 'subsJ': each element is kept or dropped, and the
-- homomorphism is evaluated with 'parallelJoinListHom'.
subsJP :: (NFData s) => JoinList a -> Semiring a s -> s
subsJP xs (GenericSemiring {..}) = parallelJoinListHom (joinListAlgebra times keepOrDrop nil) xs
    where keepOrDrop x = single x `oplus` nil
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid
          
{-| Parallel version of `assigns`: same result, computed with parallel evaluation. -}
assignsP :: (NFData s) => [m] -> [a] -> Semiring (m, a) s -> s
assignsP ms xs = assignsJP ms (joinize xs)
-- | Parallel counterpart of 'assignsJ': each element @x@ becomes the sum of
-- @single (m, x)@ over @ms@, evaluated with 'parallelJoinListHom'.
assignsJP :: (NFData s) => [m] -> JoinList a -> Semiring (m,a) s -> s
assignsJP ms xs (GenericSemiring {..}) = parallelJoinListHom (joinListAlgebra times choices nil) xs
    where choices x = foldr oplus identity [ single (m, x) | m <- ms ]
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid

{-| Parallel version of `assignsBy`: same result, computed with parallel evaluation. -}
assignsByP :: (NFData s) => (a -> [m]) -> [a] -> Semiring (m, a) s -> s
assignsByP f xs = assignsByJP f (joinize xs)
-- | Parallel counterpart of 'assignsByJ': each element @x@ becomes the sum of
-- @single (m, x)@ over @f x@, evaluated with 'parallelJoinListHom'.
assignsByJP :: (NFData s) => (a -> [m]) -> JoinList a -> Semiring (m,a) s -> s
assignsByJP f xs (GenericSemiring {..}) = parallelJoinListHom (joinListAlgebra times choices nil) xs
    where choices x = foldr oplus identity [ single (m, x) | m <- f x ]
          JoinListAlgebra {..} = algebra
          CommutativeMonoid {..} = monoid



{-| Evaluates the homomorphism induced by a 'JoinListAlgebra' on a join list,
evaluating the two operands of 'Times' nodes in parallel.

Only the top 6 levels of 'Times' nodes are parallelised, so at most
@2^6 = 64@ sub-computations are sparked. Deeper levels are evaluated
sequentially, as 'hom' would evaluate them.

Each parallel operand is evaluated to normal form with 'force', which is
what the 'NFData' constraint is for. Sparking a value only to weak head
normal form does almost no work when the carrier is a lazy structure, such
as the tuples used by 'segsJP', 'initsJP' and 'tailsJP'. As a consequence,
on the parallel levels every component of the partial results is computed,
including components that the final answer later discards.
-}
parallelJoinListHom :: forall t a. (NFData a) => JoinListAlgebra t a -> JoinList t -> a
parallelJoinListHom (JoinListAlgebra {..}) = h maxDepth
    where maxDepth = 6 :: Int  -- at most 2^6 = 64 parallel sparks
          h n (x1 `Times` x2)
              | n > 0     = p1' `par` (p2' `pseq` (p1' `times` p2'))
              | otherwise = p1 `times` p2
              where p1  = h (n - 1) x1
                    p2  = h (n - 1) x2
                    -- Evaluating these to WHNF evaluates p1/p2 to normal
                    -- form. They are used again in the continuation, so the
                    -- sparked value stays reachable until it is needed.
                    p1' = force p1
                    p2' = force p2
          h _ (Single a) = single a
          h _ Nil        = nil

--- useful functions to design generators: constructors of bags of lists
{-| Constructor of a bag of join lists. It concatenates every list of the
first bag with every list of the second bag; this is the @times@ of the free
semiring on join lists.

For example,

>>> (bag (map joinize [[1,2], [3]])) `crossConcat` (bag (map joinize [[4,5], [6]]))
Bag [[1,2,4,5],[1,2,6],[3,4,5],[3,6]]

 -}
crossConcat :: Bag (JoinList a) -> Bag (JoinList a) -> Bag (JoinList a)
crossConcat xs ys = times (algebra freeSemiring) xs ys

{-| Constructor of a bag of join lists: the bag containing only the
one-element list of the given value. This is the @single@ of the free
semiring on join lists.

For example,

>>> bagOfSingleton 1
Bag [[1]]

 -}
bagOfSingleton :: a -> Bag (JoinList a)
bagOfSingleton x = single (algebra freeSemiring) x

{-| Constructor of a bag of join lists: the bag containing only the empty
list. This is the @nil@ of the free semiring on join lists.

For example,

>>> bagOfNil
Bag [[]]

-}
bagOfNil :: Bag (JoinList a)
bagOfNil = nil $ algebra freeSemiring

{-| Constructor of a bag of join lists: the bag with no lists at all. This is
the monoid @identity@ of the free semiring on join lists.

For example,

>>> emptyBag
Bag []

-}
emptyBag :: Bag (JoinList a)
emptyBag = identity (monoid semiring)
  where semiring = freeSemiring :: GenericSemiring (JoinListAlgebra a) (Bag (JoinList a))

{-| Constructor of a bag of join lists: the union of two bags. This is the
monoid @oplus@ of the free semiring on join lists.

For example,

>>> (bag (map joinize [[1,2], [3]])) `bagUnion` (bag (map joinize [[4,5], [6]]))
Bag [[1,2],[3],[4,5],[6]]

 -}
bagUnion :: Bag (JoinList a) -> Bag (JoinList a) -> Bag (JoinList a)
bagUnion xs ys = oplus (monoid semiring) xs ys
  where semiring = freeSemiring :: GenericSemiring (JoinListAlgebra a) (Bag (JoinList a))