{-# LANGUAGE MultiParamTypeClasses,FlexibleInstances,FlexibleContexts,FunctionalDependencies,UndecidableInstances,RankNTypes,ExplicitForAll,ScopedTypeVariables,NoMonomorphismRestriction,OverlappingInstances,EmptyDataDecls,RecordWildCards,TypeFamilies,TemplateHaskell  #-}

module GTA.Core (Bag(Bag), CommutativeMonoid (CommutativeMonoid), oplus, identity, GenericSemiring (GenericSemiring), monoid, algebra, GenericSemiringStructure, freeSemiring, liftedSemiring, pairSemiring, shom, hom, makeAlgebra, freeAlgebra, pairAlgebra, foldingAlgebra, bag, (>==), (>=>), (>=<), (>##), (>#>), (<.>), items, revOrd, RevOrd(RevOrd), maxsumBy, maxsumKBy, maxsumsolutionXKBy, maxsumsolutionXBy, maxsumsolutionBy, maxsumsolutionKBy, maxprodBy, maxprodKBy, maxprodsolutionXKBy, maxprodsolutionXBy, maxprodsolutionBy, maxprodsolutionKBy, maxMonoSumBy, maxMonoSumsolutionXBy, maxMonoSumKBy, maxMonoSumsolutionXKBy, addIdentity, AddIdentity (AddIdentity, Identity), sumproductBy, result, filterBy, aggregateBy, transformBy, ) where

import Data.List
import Data.Map (Map,empty, singleton, unionWith,assocs)


-- The bag (multiset): a collection whose element order carries no meaning.
-- It is represented by a plain list.
data Bag a = Bag [a] deriving (Show, Read)

-- Two bags are equal when their sorted element lists are equal, i.e. the
-- order of the underlying lists is ignored.
instance (Ord a) => Eq (Bag a) where
  Bag xs == Bag ys = sort xs == sort ys

-- Bags are ordered by their sorted element lists so that `compare` agrees
-- with (==): `compare x y == EQ` exactly when `x == y`.  A derived instance
-- would compare the raw lists and disagree with (==) on bags that differ
-- only in element order.
instance (Ord a) => Ord (Bag a) where
  compare (Bag xs) (Bag ys) = compare (sort xs) (sort ys)

-- Unwraps a bag, returning the underlying list as it is stored.
items :: Bag a -> [a]
items b = case b of
  Bag xs -> xs

-- Wraps a list as a bag; the list is stored unchanged.
bag :: forall a. [a] -> Bag a
bag = Bag

-- Bag filter: keeps the elements of the bag that satisfy the predicate.
filterB :: forall a. (a -> Bool) -> Bag a -> Bag a
filterB keep (Bag xs) = Bag [x | x <- xs, keep x]

-- A commutative monoid, given explicitly as a record of its operations.
data CommutativeMonoid a = CommutativeMonoid
  { oplus    :: a -> a -> a  -- binary operation; commutative, associative
  , identity :: a            -- the identity of oplus
  }

-- Bags form a commutative monoid: combining two bags concatenates their
-- underlying lists, and the empty bag is the identity.
bagMonoid :: forall a. CommutativeMonoid (Bag a)
bagMonoid = CommutativeMonoid { oplus = append, identity = Bag [] }
  where
    append (Bag xs) (Bag ys) = Bag (xs ++ ys)

-- Finite maps form a commutative monoid once their values do: two maps are
-- merged with unionWith, combining values under a shared key with the
-- value monoid's oplus; the empty map is the identity.
mapMonoid :: forall k a. Ord k => CommutativeMonoid a -> CommutativeMonoid (Map k a)
mapMonoid m =
  CommutativeMonoid
    { oplus    = unionWith (oplus m)
    , identity = empty
    }

-- Singleton bag: the bag holding exactly the given element.
singletonBag :: forall a. a -> Bag a
singletonBag = Bag . (: [])

-- Tupled monoid: pairs are combined componentwise, each component with its
-- own monoid, and the identity is the pair of the two identities.
pairMonoid :: forall t t1.CommutativeMonoid t -> CommutativeMonoid t1 -> CommutativeMonoid (t, t1)
pairMonoid m1 m2 =
  CommutativeMonoid
    { oplus    = \(l1, l2) (r1, r2) -> (oplus m1 l1 r1, oplus m2 l2 r2)
    , identity = (identity m1, identity m2)
    }

-- Generic Semiring: a commutative monoid on the carrier `a` together with
-- an algebra of shape `alg` on the same carrier.
data GenericSemiring alg a = GenericSemiring
  { monoid  :: CommutativeMonoid a
  , algebra :: alg a
  }

-- Operations relating an algebra shape `alg` to its free structure `free`
-- and its `uniformer`; both are determined by `alg` through the functional
-- dependencies.  makeAlgebra, pairAlgebra, freeAlgebra, hom and
-- foldingAlgebra have no default here and must be given by each instance;
-- the remaining methods have defaults built from them.
class GenericSemiringStructure alg free uniformer | alg -> free, alg -> uniformer where
  freeSemiring   :: GenericSemiring alg (Bag free)
  liftedSemiring :: (Ord c) => GenericSemiring alg a -> alg c -> GenericSemiring alg (Map c a)
  pairSemiring   :: GenericSemiring alg a -> GenericSemiring alg b -> GenericSemiring alg (a,b)
  shom           :: GenericSemiring alg a -> Bag free -> a {- for inefficient impl. -}
  makeAlgebra    :: (CommutativeMonoid m) -> (alg a) -> (m->[a]) -> (a -> m) -> alg m
  pairAlgebra    :: alg a -> alg b -> alg (a,b)
  freeAlgebra    :: alg free
  hom            :: alg a -> free -> a                      {- for inefficient impl. -}
  foldingAlgebra :: (a -> a -> a) -> a -> uniformer a -> alg a

  -- The bag monoid paired with the free algebra lifted to bags.
  freeSemiring =
    GenericSemiring
      { monoid  = bagMonoid
      , algebra = makeAlgebra bagMonoid freeAlgebra items singletonBag
      }

  -- Lifts a semiring to finite maps keyed by the results of the algebra `a`:
  -- maps are merged with the map monoid, and the algebra works on the
  -- (key, value) pairs of the maps.
  liftedSemiring s a =
    GenericSemiring
      { monoid  = tableMonoid
      , algebra = makeAlgebra tableMonoid (pairAlgebra a (algebra s)) assocs (uncurry singleton)
      }
    where
      tableMonoid = mapMonoid (monoid s)

  -- Applies hom of the semiring's algebra to every element of the bag and
  -- folds the results with the semiring's monoid.
  shom (GenericSemiring mon algb) = collapse
    where
      collapse (Bag xs) = foldr (oplus mon) (identity mon) (map (hom algb) xs)

  -- Componentwise pairing of two semirings.
  pairSemiring s1 s2 =
    GenericSemiring
      { monoid  = pairMonoid (monoid s1) (monoid s2)
      , algebra = pairAlgebra (algebra s1) (algebra s2)
      }



-- combinators with optimizations

-- Generator + Filter = Generator
-- Runs the generator on the given semiring lifted over the tester algebra,
-- which yields a table, then combines (with the monoid of the given
-- semiring) the values whose key passes the predicate.
infixl 5 >==
(>==) :: forall (alg :: * -> *) free (uniformer :: * -> *) c b k.
         (GenericSemiringStructure alg free uniformer, Ord c) =>
         (GenericSemiring alg (Map c b) -> Map k b)
      -> (k -> Bool, alg c)
      -> GenericSemiring alg b
      -> b
(>==) pgen (ok, bt) bts = foldr (oplus m) (identity m) accepted
  where
    m        = monoid bts
    table    = pgen (liftedSemiring bts bt)
    accepted = [v | (key, v) <- assocs table, ok key]

-- Generator + Aggregator = Result
-- Hands the aggregator semiring to the generator.
infixl 5 >=>
(>=>) :: forall (alg :: * -> *) free (uniformer :: * -> *) b k.
         (GenericSemiringStructure alg free uniformer) =>
         (GenericSemiring alg b -> b)
      -> GenericSemiring alg b
      -> b
pgen >=> bts = pgen bts
       
-- Generator_A + Transformer_{A->B} = Generator_B
-- The semiring is first passed through the transformer and the result is
-- given to the generator.
infixl 5 >=<
(>=<) :: forall (alg :: * -> *) free (uniformer :: * -> *)
                (alg' :: * -> *) free' (uniformer' :: * -> *)
                c.
         (GenericSemiringStructure alg free uniformer,
          GenericSemiringStructure alg' free' uniformer') =>
         (GenericSemiring alg' c -> c)
      -> (GenericSemiring alg c -> GenericSemiring alg' c)
      -> GenericSemiring alg c
      -> c
(>=<) pgen trans bts = pgen (trans bts)

-- aliases
-- Named form of the (>==) operator.
filterBy :: forall (alg :: * -> *) free (uniformer :: * -> *) c b k.
            (GenericSemiringStructure alg free uniformer, Ord c) =>
            (GenericSemiring alg (Map c b) -> Map k b)
         -> (k -> Bool, alg c)
         -> GenericSemiring alg b
         -> b
filterBy pgen test bts = (pgen >== test) bts

-- Named form of the (>=>) operator.
aggregateBy :: forall (alg :: * -> *) free (uniformer :: * -> *) b k.
               (GenericSemiringStructure alg free uniformer) =>
               (GenericSemiring alg b -> b)
            -> GenericSemiring alg b
            -> b
aggregateBy pgen bts = pgen >=> bts

-- Named form of the (>=<) operator.
transformBy :: forall (alg :: * -> *) free (uniformer :: * -> *)
                      (alg' :: * -> *) free' (uniformer' :: * -> *)
                      c.
               (GenericSemiringStructure alg free uniformer,
                GenericSemiringStructure alg' free' uniformer') =>
               (GenericSemiring alg' c -> c)
            -> (GenericSemiring alg c -> GenericSemiring alg' c)
            -> GenericSemiring alg c
            -> c
transformBy pgen trans bts = (pgen >=< trans) bts



-- combinators without optimizations 
-- Unoptimised filter: generates every candidate with the free semiring and
-- keeps those whose image under `hom bt` passes the predicate.
-- The semiring passed as the third argument is ignored by this operator.
infixl 5 >##
(>##) :: (GenericSemiringStructure alg free uniformer) =>
         (GenericSemiring alg (Bag free) -> Bag free)
      -> (b -> Bool, alg b)
      -> GenericSemiring alg (Bag free)
      -> Bag free
(>##) pgen (ok, bt) _ = filterB (ok . hom bt) candidates
  where
    candidates = pgen freeSemiring
        
-- Unoptimised aggregation: generates every candidate with the free
-- semiring, then folds the resulting bag with shom of the given semiring.
infixl 5 >#>
(>#>) :: (GenericSemiringStructure alg free uniformer) =>
         (GenericSemiring alg (Bag free) -> Bag free)
      -> GenericSemiring alg a
      -> a
pgen >#> bts = shom bts candidates
  where
    candidates = pgen freeSemiring


-- Pairs a predicate with an algebra, so that 'ok <.> alg' can be written
-- where 'ok . hom' would otherwise appear.  It simply builds the tuple.
infix 6 <.>
(<.>) :: forall t t1. t -> t1 -> (t, t1)
ok <.> alg = (ok, alg)


-- Aggregator for generating all candidates passing tests: this is the free
-- semiring, whose carrier is a bag of free structures.
result :: forall (alg :: * -> *) free (uniformer :: * -> *).
          GenericSemiringStructure alg free uniformer =>
          GenericSemiring alg (Bag free)
result = freeSemiring


-- Aggregator based on the usual semirings.
-- The first operation and its unit become the commutative monoid; the
-- second operation, its unit and the uniformer are passed to
-- foldingAlgebra to build the algebra.
genAlgebraFromSemiring :: forall free (uniformer :: * -> *) (alg :: * -> *) a.
                          GenericSemiringStructure alg free uniformer =>
                          (a -> a -> a)
                       -> a
                       -> (a -> a -> a)
                       -> a
                       -> uniformer a
                       -> GenericSemiring alg a
genAlgebraFromSemiring op iop ot iot mf =
  GenericSemiring
    { monoid  = CommutativeMonoid { oplus = op, identity = iop }
    , algebra = foldingAlgebra ot iot mf
    }

-- The sum-product semiring: (+) with 0 as the monoid, (*) with 1 for the
-- algebra.
sumproductBy :: forall free (uniformer :: * -> *) (alg :: * -> *) a.
                (GenericSemiringStructure alg free uniformer, Num a) =>
                uniformer a -> GenericSemiring alg a
sumproductBy mf = genAlgebraFromSemiring (+) 0 (*) 1 mf

-- A type extended with one extra element, Identity.
data AddIdentity a = AddIdentity a | Identity deriving (Show, Eq, Read)

-- Identity is smaller than every wrapped value; wrapped values compare by
-- their contents.
instance (Ord a) => Ord (AddIdentity a) where
  compare lhs rhs = case (lhs, rhs) of
    (AddIdentity a, AddIdentity b) -> compare a b
    (Identity, Identity)           -> EQ
    (Identity, _)                  -> LT
    (_, Identity)                  -> GT

-- Wraps a value with the AddIdentity constructor.
addIdentity :: forall a. a -> AddIdentity a
addIdentity = AddIdentity

-- max-sum semiring
-- The monoid is `max` with Identity as its unit.  The algebra is built from
-- the given operation lifted to AddIdentity, where Identity is absorbing,
-- with `AddIdentity mid` as the unit.
maxMonoSumBy :: forall free (uniformer :: * -> *) (alg :: * -> *) a.
                (GenericSemiringStructure alg free uniformer, Ord a) =>
                (a -> a -> a)
             -> a
             -> uniformer (AddIdentity a)
             -> GenericSemiring alg (AddIdentity a)
maxMonoSumBy mplus mid = genAlgebraFromSemiring max Identity plus (AddIdentity mid)
  where
    -- Any Identity operand makes the result Identity.
    plus (AddIdentity a) (AddIdentity b) = AddIdentity (mplus a b)
    plus _ _ = Identity

-- max-MonoSum with computation
-- Pairs each score with a value of the semiring `s`.  When two pairs are
-- combined the one with the larger score wins; on equal scores the two
-- values are merged with the monoid of `s`.
maxMonoSumsolutionXBy :: forall free (uniformer :: * -> *) a t (alg :: * -> *).
                         (GenericSemiringStructure alg free uniformer, Ord a) =>
                         (a -> a -> a)
                      -> a
                      -> GenericSemiring alg t
                      -> uniformer (AddIdentity a)
                      -> GenericSemiring alg (AddIdentity a, t)
maxMonoSumsolutionXBy mplus mid s mf =
  GenericSemiring
    { monoid  = CommutativeMonoid { oplus = pick, identity = (Identity, identity inner) }
    , algebra = pairAlgebra scoreAlgebra (algebra s)
    }
  where
    inner        = monoid s
    scoreAlgebra = algebra (maxMonoSumBy mplus mid mf)
    pick (a, x) (b, y) =
      case compare a b of
        EQ -> (a, oplus inner x y)
        LT -> (b, y)
        GT -> (a, x)

-- max-k
-- Carrier: a list of scores.  Combining two lists sorts their concatenation
-- in descending order, collapses runs of equal scores to one entry and
-- keeps the first `k` entries; the empty list is the identity.  The algebra
-- is the max-MonoSum algebra lifted to such lists through makeAlgebra.
maxMonoSumKBy :: forall a free (uniformer :: * -> *) (alg :: * -> *).
                 (GenericSemiringStructure alg free uniformer, Ord a) =>
                 (a -> a -> a)
              -> a
              -> Int
              -> uniformer (AddIdentity a)
              -> GenericSemiring alg [AddIdentity a]
maxMonoSumKBy mplus mid k mf =
  GenericSemiring
    { monoid  = topK
    , algebra = makeAlgebra topK scoreAlgebra id (: [])
    }
  where
    topK = CommutativeMonoid { oplus = merge, identity = [] }
    -- The pattern `best : _` selects the head of each group without
    -- resorting to the partial function `head`.
    merge xs ys = take k [best | best : _ <- group (reverse (sort (xs ++ ys)))]
    scoreAlgebra = algebra (maxMonoSumBy mplus mid mf)

-- max-solution-k
-- Carrier: a list of (score, value) pairs, the values belonging to the
-- semiring `s`.  Combining two lists sorts their concatenation by score in
-- descending order, merges entries with equal scores (their values are
-- combined with the monoid of `s`) and keeps the first `k` entries; the
-- empty list is the identity.
maxMonoSumsolutionXKBy :: forall a free (uniformer :: * -> *) b (alg :: * -> *).
                          (GenericSemiringStructure alg free uniformer, Ord a) =>
                          (a -> a -> a)
                       -> a
                       -> GenericSemiring alg b
                       -> Int
                       -> uniformer (AddIdentity a)
                       -> GenericSemiring alg [(AddIdentity a, b)]
maxMonoSumsolutionXKBy mplus mid s k mf =
  GenericSemiring
    { monoid  = topK
    , algebra = makeAlgebra topK (pairAlgebra scoreAlgebra (algebra s)) id (: [])
    }
  where
    inner        = monoid s
    scoreAlgebra = algebra (maxMonoSumBy mplus mid mf)
    topK         = CommutativeMonoid { oplus = merge, identity = [] }
    merge xs ys = take k (map (foldr1 combine) groups)
      where
        byScore p q = compare (fst p) (fst q)
        descending  = reverse (sortBy byScore (xs ++ ys))
        -- groupBy never yields an empty group, so foldr1 is safe here.
        groups      = groupBy (\p q -> byScore p q == EQ) descending
        combine (score, sol) (_, sol') = (score, oplus inner sol sol')

-- max-sum
-- maxMonoSumBy with (+) as the operation and 0 as its unit.
maxsumBy :: forall free (uniformer :: * -> *) (alg :: * -> *) a.
            (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
            uniformer (AddIdentity a)
         -> GenericSemiring alg (AddIdentity a)
maxsumBy mf = maxMonoSumBy (+) 0 mf

-- maxMonoSumKBy with (+) as the operation and 0 as its unit.
maxsumKBy :: forall a free (uniformer :: * -> *) (alg :: * -> *).
             (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
             Int
          -> uniformer (AddIdentity a)
          -> GenericSemiring alg [AddIdentity a]
maxsumKBy k mf = maxMonoSumKBy (+) 0 k mf

-- maxMonoSumsolutionXKBy with (+) as the operation and 0 as its unit.
maxsumsolutionXKBy :: forall a free (uniformer :: * -> *) b (alg :: * -> *).
                      (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                      GenericSemiring alg b
                   -> Int
                   -> uniformer (AddIdentity a)
                   -> GenericSemiring alg [(AddIdentity a, b)]
maxsumsolutionXKBy s k mf = maxMonoSumsolutionXKBy (+) 0 s k mf


-- maxMonoSumsolutionXBy with (+) as the operation and 0 as its unit.
maxsumsolutionXBy :: forall free (uniformer :: * -> *) a t (alg :: * -> *).
                     (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                     GenericSemiring alg t
                  -> uniformer (AddIdentity a)
                  -> GenericSemiring alg (AddIdentity a, t)
maxsumsolutionXBy s mf = maxMonoSumsolutionXBy (+) 0 s mf


-- maxsumsolutionXBy with the free semiring as the accompanying semiring, so
-- each score is paired with a bag of free structures.
maxsumsolutionBy :: forall a (alg :: * -> *) free (uniformer :: * -> *).
                    (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                    uniformer (AddIdentity a)
                 -> GenericSemiring alg (AddIdentity a, Bag free)
maxsumsolutionBy mf = maxsumsolutionXBy freeSemiring mf


-- maxsumsolutionXKBy with the free semiring as the accompanying semiring,
-- so each score is paired with a bag of free structures.
maxsumsolutionKBy :: forall a (alg :: * -> *) free (uniformer :: * -> *).
                     (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                     Int
                  -> uniformer (AddIdentity a)
                  -> GenericSemiring alg [(AddIdentity a, Bag free)]
maxsumsolutionKBy k mf = maxsumsolutionXKBy freeSemiring k mf

-- max prod (on positive numbers)
-- maxMonoSumBy with (*) as the operation and 1 as its unit.
maxprodBy :: forall free (uniformer :: * -> *) (alg :: * -> *) a.
             (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
             uniformer (AddIdentity a)
          -> GenericSemiring alg (AddIdentity a)
maxprodBy mf = maxMonoSumBy (*) 1 mf

-- maxMonoSumKBy with (*) as the operation and 1 as its unit.
maxprodKBy :: forall a free (uniformer :: * -> *) (alg :: * -> *).
              (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
              Int
           -> uniformer (AddIdentity a)
           -> GenericSemiring alg [AddIdentity a]
maxprodKBy k mf = maxMonoSumKBy (*) 1 k mf


-- maxMonoSumsolutionXKBy with (*) as the operation and 1 as its unit.
maxprodsolutionXKBy :: forall a free (uniformer :: * -> *) b (alg :: * -> *).
                       (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                       GenericSemiring alg b
                    -> Int
                    -> uniformer (AddIdentity a)
                    -> GenericSemiring alg [(AddIdentity a, b)]
maxprodsolutionXKBy s k mf = maxMonoSumsolutionXKBy (*) 1 s k mf

-- maxMonoSumsolutionXBy with (*) as the operation and 1 as its unit.
maxprodsolutionXBy :: forall free (uniformer :: * -> *) a t (alg :: * -> *).
                      (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                      GenericSemiring alg t
                   -> uniformer (AddIdentity a)
                   -> GenericSemiring alg (AddIdentity a, t)
maxprodsolutionXBy s mf = maxMonoSumsolutionXBy (*) 1 s mf

-- maxprodsolutionXBy with the free semiring as the accompanying semiring,
-- so each score is paired with a bag of free structures.
maxprodsolutionBy :: forall a (alg :: * -> *) free (uniformer :: * -> *).
                     (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                     uniformer (AddIdentity a)
                  -> GenericSemiring alg (AddIdentity a, Bag free)
maxprodsolutionBy mf = maxprodsolutionXBy freeSemiring mf

-- maxprodsolutionXKBy with the free semiring as the accompanying semiring,
-- so each score is paired with a bag of free structures.
maxprodsolutionKBy :: forall a (alg :: * -> *) free (uniformer :: * -> *).
                      (GenericSemiringStructure alg free uniformer, Ord a, Num a) =>
                      Int
                   -> uniformer (AddIdentity a)
                   -> GenericSemiring alg [(AddIdentity a, Bag free)]
maxprodsolutionKBy k mf = maxprodsolutionXKBy freeSemiring k mf

-- Wraps a value in RevOrd, whose ordering is reversed, so that `max` on
-- the wrapped values acts as `min` on the originals.
revOrd :: forall a. a -> RevOrd a
revOrd = RevOrd

-- Wrapper around a single value; equality, showing and reading are derived.
data RevOrd a = RevOrd a
  deriving (Eq, Show, Read)

-- Arithmetic on RevOrd is the arithmetic of the wrapped values.
instance (Num a) => (Num (RevOrd a)) where
  RevOrd a + RevOrd b = RevOrd (a + b)
  RevOrd a * RevOrd b = RevOrd (a * b)
  RevOrd a - RevOrd b = RevOrd (a - b)
  negate (RevOrd a)   = RevOrd (negate a)
  abs (RevOrd a)      = RevOrd (abs a)
  signum (RevOrd a)   = RevOrd (signum a)
  fromInteger         = RevOrd . fromInteger
  

-- The ordering of RevOrd is that of the wrapped values with LT and GT
-- exchanged.
instance (Ord a) => (Ord (RevOrd a)) where
  compare (RevOrd a) (RevOrd b) = mirror (compare a b)
    where
      mirror LT = GT
      mirror GT = LT
      mirror EQ = EQ