{-# LANGUAGE MultiParamTypeClasses,FlexibleInstances,FlexibleContexts,FunctionalDependencies,UndecidableInstances,RankNTypes,ExplicitForAll,ScopedTypeVariables,NoMonomorphismRestriction,OverlappingInstances,TemplateHaskell  #-}

module GTA.Util.GenericSemiringStructureTemplate (genAlgebraDecl, genMapFunctionsDecl, genInstanceDecl, genAllDecl) where

import Language.Haskell.TH
import GTA.Util.TypeInfo 
import Data.Char
{-
reference: 
 http://www.haskell.org/haskellwiki/Template_haskell/Instance_deriving_example
-}

{- exported functions -}
-- | Splice producing the single algebra-record declaration (see
-- 'genAlgebraRecord') for the data type with the given name.
genAlgebraDecl :: Name -> Q [Dec]
genAlgebraDecl typName = do
  (tyName, tyParams, cons) <- typeInfo typName
  fmap (: []) (genAlgebraRecord tyName tyParams cons)

-- | Splice producing the single map-functions record declaration (see
-- 'genMapFunctionsRecord') for the data type with the given name.
genMapFunctionsDecl :: Name -> Q [Dec]
genMapFunctionsDecl typName = do
  (tyName, tyParams, cons) <- typeInfo typName
  fmap (: []) (genMapFunctionsRecord tyName tyParams cons)

-- | Splice producing the @GenericSemiringStructure@ instance declaration (see
-- 'genSemiringInstance') for the data type with the given name.
genInstanceDecl :: Name -> Q [Dec]
genInstanceDecl typName = do
  (tyName, tyParams, cons) <- typeInfo typName
  fmap (: []) (genSemiringInstance tyName tyParams cons)

-- | Splice producing, in order, the algebra record, the map-functions record
-- and the semiring-structure instance for the named data type.
genAllDecl :: Name -> Q [Dec]
genAllDecl typName =
  fmap concat (mapM ($ typName) [genAlgebraDecl, genMapFunctionsDecl, genInstanceDecl])

{-
Given a data type like
 data BTree a = Node a (BTree a) (BTree a)
              | Leaf a
, this generates a record type corresponding to the algebra like
 data BTreeAlgebra b a = 
   BTreeAlgebra {
     node :: b -> a -> a -> a,
     leaf :: b -> a
   }
.
-}
-- | Build the algebra record for a data type.
--
-- The record is named by 'algebraName' and takes the data type's parameters
-- plus an extra result parameter @gta@. It has one strict field per
-- constructor, named by 'funcName'. Each field's type is the constructor's
-- field types followed by @gta@, where every recursive occurrence of the data
-- type (see 'genFreeType') is replaced by @gta@.
--
-- Field types are spliced in unchanged, so constructors whose fields are not
-- bare type variables (e.g. @Int@) are accepted as well.
genAlgebraRecord :: forall t.Name -> [TyVarBndr] -> [(Name, [(t, Type)])] -> DecQ
genAlgebraRecord typeName typeParams constructors =
  let gta = mkName "gta"
      carrier = VarT gta
      newParams = typeParams ++ [PlainTV gta]
      dataName = algebraName typeName
      freeType = genFreeType typeName typeParams
      -- f1 -> f2 -> ... -> gta, with recursive positions mapped to gta
      fieldType params =
        arrowConcat (map return (replace freeType carrier (map snd params ++ [carrier])))
      genFun (name, params) =
        varStrictType (funcName name) (strictType notStrict (fieldType params))
      con = recC dataName (map genFun constructors)
  in dataD (cxt []) dataName newParams [con] []

{-
data BTreeMapFs b b' = BTreeMapFs {
         nodeF :: (b -> b'),
         leafF :: (b -> b')
       }
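This record has one field for each constructor with at least one non-recursive field. The field is a function (or, for several such fields, a tuple of functions) mapping each non-recursive field to the common result type, so that all their values end up with the same type.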
-}
-- | Build the map-functions record for a data type.
--
-- The record is named by 'mapFunctionsName' and takes the data type's
-- parameters plus an extra result parameter @gta@. It has one strict field for
-- each constructor that has at least one non-recursive field. The field is
-- named by 'mfFuncName' of 'funcName' of the constructor. Its type is a
-- function @f -> gta@ for each non-recursive field type @f@, tupled together
-- when there are several (see 'mkTupleType').
genMapFunctionsRecord :: forall t.Name -> [TyVarBndr] -> [(Name, [(t, Type)])] -> DecQ
genMapFunctionsRecord typeName typeParams constructors =
  let gta = mkName "gta"
      newParams = typeParams ++ [PlainTV gta]
      mapName = mapFunctionsName typeName
      freeType = genFreeType typeName typeParams
      fieldName = mfFuncName . funcName
      -- the field types of a constructor that are not the data type itself
      nonRecTypes params = [t | (_, t) <- params, t /= freeType]
      -- constructors with only recursive fields (or none) get no entry
      constructors' = [ (name, ts)
                      | (name, params) <- constructors
                      , let ts = nonRecTypes params
                      , not (null ts) ]
      toCarrier t = appT (appT arrowT (return t)) (varT gta)
      genFun (name, ts) =
        varStrictType (fieldName name) (strictType notStrict (mkTupleType (map toCarrier ts)))
      con = recC mapName (map genFun constructors')
  in dataD (cxt []) mapName newParams [con] []
     
-- | Combine types into a tuple type. A single type is returned as-is instead
-- of being wrapped in a 1-tuple.
mkTupleType :: [TypeQ] -> TypeQ
mkTupleType ts = case ts of
  [single] -> single
  _        -> foldl appT (tupleT (length ts)) ts
{-
  instance GenericSemiringStructure (BTreeAlgebra b) (BTree b) (BTreeMapFs b) where
-}
-- | Build the @GenericSemiringStructure@ instance for a data type.
--
-- The instance head applies the class to the algebra record, the data type
-- and the map-functions record, each applied to the data type's parameters.
-- The body holds freeAlgebra, hom, pairAlgebra, makeAlgebra and
-- foldingAlgebra.
--
-- Both plain and kind-annotated type variable binders are accepted.
genSemiringInstance :: forall t.Name -> [TyVarBndr] -> [(Name, [(t, Type)])] -> DecQ
genSemiringInstance typeName typeParams constructors =
  let className = mkName "GenericSemiringStructure"
      dataName = algebraName typeName
      mapName = mapFunctionsName typeName
      tvName (PlainTV v) = v
      tvName (KindedTV v _) = v
      -- T  ==>  T p1 p2 ... for the data type's parameters
      applyParams e = foldl appT e (map (varT . tvName) typeParams)
      instanceType =
        foldl appT (conT className)
          [applyParams (conT dataName),
           applyParams (conT typeName),
           applyParams (conT mapName)]
      funcs = [genFreeAlgebra typeName typeParams constructors,
               genHom typeName typeParams constructors,
               genPairAlgebra typeName typeParams constructors,
               genMakeAlgebra typeName typeParams constructors,
               genFoldingAlgebra typeName typeParams constructors]
  in instanceD (cxt []) instanceType funcs

{-
  freeAlgebra = BTreeAlgebra {..} where
     node = Node
     leaf = Leaf
-}
-- | Generate @freeAlgebra@: an algebra record whose field for each
-- constructor is the constructor itself.
genFreeAlgebra :: forall t t1. Name -> t -> [(Name, t1)] -> DecQ
genFreeAlgebra typeName _ constructors =
  funD (mkName "freeAlgebra") [clause [] (normalB body) (map aliasCon constructors)]
  where
    body = recConE (algebraName typeName)
                   (genWildcardFieldExp [funcName c | (c, _) <- constructors])
    -- node = Node, leaf = Leaf, ...
    aliasCon (c, _) = funD (funcName c) [clause [] (normalB (conE c)) []]

{-
  pairAlgebra bt1 bt2 = BTreeAlgebra {..} 
    where
      node a (l1, l2) (r1, r2) = (node1 a l1 r1, node2 a l2 r2)
      leaf a = (leaf1 a, leaf2 a)
      (leaf1, node1) = let BTreeAlgebra {..} = bt1 in (leaf, node)
      (leaf2, node2) = let BTreeAlgebra {..} = bt2 in (leaf, node)
-}
-- | Generate @pairAlgebra algebra1 algebra2@, which runs both algebras
-- side by side.
--
-- Each constructor function returns a pair. Its first component applies
-- algebra1's function to the first components of the recursive arguments,
-- and its second component applies algebra2's function to the second
-- components. Non-recursive arguments are passed to both unchanged.
genPairAlgebra :: forall t.Name -> [TyVarBndr] -> [(Name, [(t, Type)])] -> DecQ
genPairAlgebra typeName typeParams constructors =
  let
    alg1 = mkName "algebra1"
    alg2 = mkName "algebra2"
    vps = map varP [alg1, alg2]
    fs = map (\(n, _) -> funcName n) constructors
    -- bind node1, leaf1, ... from algebra1 and node2, leaf2, ... from algebra2
    binds = [recBind (algebraName typeName) fs (varE alg1) (name 1),
             recBind (algebraName typeName) fs (varE alg2) (name 2)]
    -- suffix a name with an index: name 1 node = node1
    name :: Int -> Name -> Name
    name i = mkName . (++ show i) . nameBase
    -- each recursive argument is split into its two components ...
    bindExp = id
    bindPat a = tupP [varP (name 1 a), varP (name 2 a)]
    newAlgebraName = mkName "pairAlgebra"
    genBody _ n' pbs = tupE [component 1, component 2]
      where
        component i = foldl1 appE (varE (name i n') : map (varE . argName i) pbs)
        -- ... and component i uses the i-th half of each recursive argument
        argName i (b, VarT a) = case b of
          Just (VarT c) -> name i c
          _             -> a
  in genAlgebraDec' typeName typeParams constructors binds newAlgebraName vps bindExp bindPat genBody

{-
  makeAlgebra (CommutativeMonoid {..}) bt frec fsingle = BTreeAlgebra {..}
    where  
    node a l r = foldr oplus identity [fsingle (node' a l' r') | l' <- frec l, r' <- frec r]
    leaf a = fsingle (leaf' a)
    (leaf', node') = let BTreeAlgebra {..} = bt in (leaf, node)

-}
-- | Generate @makeAlgebra m alg frec fsingle@.
--
-- For each constructor with recursive positions, the generated function
-- draws a value from @frec x@ for every recursive argument @x@. It applies
-- alg's function for the constructor (bound as @<name>gta@) and wraps the
-- result with @fsingle@. All such results are combined with
-- @foldr oplus identity@, where @oplus@ and @identity@ are taken from the
-- @CommutativeMonoid@ record @m@. Constructors without recursive positions
-- simply produce @fsingle@ of alg's result.
genMakeAlgebra :: forall t.Name -> [TyVarBndr] -> [(Name, [(t, Type)])] -> DecQ
genMakeAlgebra typeName typeParams constructors =
  let
    m = mkName "m"
    alg = mkName "alg"
    frec = mkName "frec"
    fsingle = mkName "fsingle"
    vps = map varP [m, alg, frec, fsingle]
    fs = map (\(n, _) -> funcName n) constructors
    -- alg's fields are bound as nodegta, leafgta, ...; m gives oplus/identity
    binds = [recBind (algebraName typeName) fs (varE alg) name,
             monoidBind (varE m)]
    name = mkName . (++ "gta") . nameBase
    -- comprehension generator for a recursive argument x:  y <- frec x
    bindExp = appE (varE frec)
    bindPat = varP
    newAlgebraName = mkName "makeAlgebra"
    genComprBody _ n' pbs =
      appE (varE fsingle) (foldl1 appE (varE (name n') : map argExp pbs))
    -- recursive positions use the comprehension-bound variable
    argExp (b, VarT a) = case b of
      Just (VarT c) -> varE c
      _             -> varE a
  in genAlgebraDec typeName typeParams constructors binds newAlgebraName vps bindExp bindPat genComprBody

{-
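  foldingAlgebra op iop mf = BTreeAlgebra {..}
    where
    node a l r = nodeF1 a `op` l `op` r
    leaf a = leafF1 a
    BTreeMapFs {nodeF = nodeF, leafF = leafF} = mf
  (nodeF1 and leafF1 are bound from nodeF and leafF; a constructor without fields yields iop.)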
-}
-- | Generate @foldingAlgebra op iop mf@.
--
-- Each constructor function combines its arguments left-to-right with
-- @op@. A recursive argument is used as-is. The k-th non-recursive argument
-- is first passed through the k-th component of the constructor's entry in
-- the map-functions record @mf@. Those components are bound as
-- @<name>F1@, @<name>F2@, ... from the tuple @<name>F@. A constructor with
-- no fields returns @iop@.
genFoldingAlgebra :: forall t.Name -> [TyVarBndr] -> [(Name, [(t, Type)])] -> DecQ
genFoldingAlgebra typeName typeParams constructors =
  let
    mf = mkName "mf"
    op = mkName "op"
    iop = mkName "iop"
    vps = map varP [op, iop, mf]
    freeType = genFreeType typeName typeParams
    -- only constructors with a non-recursive field have an entry in mf
    hasNonRec (_, ps) = any (\(_, t) -> t /= freeType) ps
    constructors' = filter hasNonRec constructors
    fs = map (\(n, _) -> mfFuncName (funcName n)) constructors'
    -- MapFs {nodeF = nodeF, leafF = leafF, ...} = mf
    binds = [recBind (mapFunctionsName typeName) fs (varE mf) id]
    newAlgebraName = mkName "foldingAlgebra"
    -- suffix a name with an index: name 1 nodeF = nodeF1
    name :: Int -> Name -> Name
    name i = mkName . (++ show i) . nameBase
    isRecursive (b, VarT _) = case b of
      Just (VarT _) -> True
      _             -> False
    -- Left pb for a recursive position; Right (fn, pb) for the k-th
    -- non-recursive position, where fn is <name>Fk
    positions n' pbs = zipWith3 tag recs ids pbs
      where
        recs = map isRecursive pbs
        ids = tail (scanl (\k r -> if r then k else k + 1) (0 :: Int) recs)
        tag True _ pb = Left pb
        tag False i pb = Right (name i (mfFuncName n'), pb)
    -- (nodeF1, nodeF2, ...) = nodeF, only when there is something to bind
    genVarbinds _ n' pbs =
      case [varP fn | Right (fn, _) <- positions n' pbs] of
        [] -> []
        ps -> [valD (tupP ps) (normalB (varE (mfFuncName n'))) []]
    genBody _ n' pbs
      | null pbs = varE iop
      | otherwise = foldl1 (\l r -> appE (appE (varE op) l) r) (map operand (positions n' pbs))
    operand (Left (_, VarT a)) = varE a
    operand (Right (fn, (_, VarT a))) = appE (varE fn) (varE a)
  in genAlgebraDec'' typeName typeParams constructors binds newAlgebraName vps genBody genVarbinds



{-  
hom (BTreeBAlgebra {..}) = h
  where
    h (NodeB a l r) = nodeB a (h l) (h r)
    h (LeafB a) = leafB a
-}
-- | Generate @hom alg@, the catamorphism for an algebra record.
--
-- The algebra's fields are bound by name. The local function @h@ has one
-- clause per constructor: it applies the algebra's field for that
-- constructor to the constructor's arguments, after recursing with @h@ into
-- every recursive argument.
--
-- Each constructor field is bound to its own fresh variable rv0, rv1, ...
-- Field type-variable names are not reused, so constructors such as
-- @Leaf a a@, type variables named @h@, and non-variable field types do not
-- produce conflicting or shadowing pattern variables.
genHom :: forall t.Name -> [TyVarBndr] -> [(Name, [(t, Type)])] -> DecQ
genHom typeName typeParams constructors =
  let
    fs = map (\(n, _) -> funcName n) constructors
    vps = [recPat (algebraName typeName) fs id]
    freeType = genFreeType typeName typeParams
    h = mkName "h"
    decls = [funD h (map genClause constructors)]
    -- rv0, rv1, ... (newVars only ever yields VarT)
    fieldNames = [v | VarT v <- newVars "rv"]
    genClause (n, ps) =
      let n' = funcName n
          -- (is this field recursive?, its variable)
          fields = zip (map (\(_, t) -> t == freeType) ps) fieldNames
          pats = [conP n [varP v | (_, v) <- fields]]
          arg (isRec, v) = if isRec then appE (varE h) (varE v) else varE v
          b = foldl appE (varE n') (map arg fields)
      in clause pats (normalB b) []
  in funD (mkName "hom") [clause vps (normalB (varE h)) decls]

{-
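This generator is split into three layers: genAlgebraDec (bodies built from a list comprehension), genAlgebraDec' (per-position pattern bindings in each function's where clause), and genAlgebraDec'' (the shared core that builds the record and one function per constructor).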

e.g., to generate the following,

  liftedAlgebra bts bt = BTreeAlgebra {..}
    where  
      node a l r = 
        foldr oplus identity [singleton (nodebt a kll krr) (nodebt' a vll vrr) | (kll, vll) <- assocs l, (krr, vrr) <- assocs r]
      leaf a = singleton (leafbt a) (leafbt' a)
      CommutativeMonoid {..} = mapMonoid (monoid bts)
      (leafbt, nodebt) = let BTreeAlgebra {..} = bt in (leaf, node)
      (leafbt', nodebt') = let BTreeAlgebra {..} = algebra bts in (leaf, node)

the function arguments are
 - (typeName, typeParams, constructors) is of typeInfo ''BTree
 - binds is a list of valDs for    
      CommutativeMonoid {..} = mapMonoid (monoid bts)
      (leafbt, nodebt) = let BTreeAlgebra {..} = bt in (leaf, node)
      (leafbt', nodebt') = let BTreeAlgebra {..} = algebra bts in (leaf, node)
 - newAlgebraName is 'liftedAlgebra'
 - vps is a list of argument patterns of the 'liftedAlgebra', i.e., [bts, bt] 
 - bindExp generates expressions (RHS of <-) of binds in the comprehension, 
   i.e., bindExp v = assocs v
 - bindPat generates patterns (LHS of <-) of binds in the comprehension,
   i.e., bindPat r = (kr, vr)
  - genComprBody generates the body of the comprehension from 
     n   ... the constructor name
     n'  ... the function name corresponding to the constructor
     pbs ... a list of (Maybe Type, Type) generated from the constructor's type 
               Here, each value of Type is of VarT x.
               If the variable has the same type as the data structure,
                the first part is Just (VarT y) 
                s.t. a bind "bindPat y <- bindExp x" is generated.
               Otherwise it is Nothing.
               For Node of BTree, pbs = [(Nothing, VarT a), 
                                         (Just (VarT ll), VarT l),
                                         (Just (VarT rr), VarT r)]
-}
-- | Comprehension-based algebra generator (see the comment above for the
-- meaning of each argument).
--
-- A constructor without recursive positions gets @genComprBody n n' pbs@ as
-- its body. A constructor with recursive positions gets
-- @foldr oplus identity [genComprBody n n' pbs | bindPat y <- bindExp x, ...]@,
-- with one generator per recursive position (x, y). No extra where-bindings
-- are added per constructor.
genAlgebraDec :: forall t.
                     Name
                     -> [TyVarBndr]
                     -> [(Name, [(t, Type)])]
                     -> [DecQ]
                     -> Name
                     -> [PatQ]
                     -> (ExpQ -> ExpQ)
                     -> (Name -> PatQ)
                     -> (Name -> Name -> [(Maybe Type, Type)] -> ExpQ)
                     -> DecQ
genAlgebraDec typeName typeParams constructors binds newAlgebraName vps bindExp bindPat genComprBody =
  let
    genVarbinds _ _ _ = []
    isRecursive (b, _) = b /= Nothing
    genBody n n' pbs
      | not (any isRecursive pbs) =
          -- no recursive position: the body is used directly
          genComprBody n n' pbs
      | otherwise =
          let bigOp = foldl1 appE (map (varE . mkName) ["foldr", "oplus", "identity"])
              bind (Just (VarT a), VarT b) = bindS (bindPat a) (bindExp (varE b))
              generators = [bind pb | pb <- pbs, isRecursive pb]
              compr = compE (generators ++ [noBindS (genComprBody n n' pbs)])
          in appE bigOp compr
  in genAlgebraDec'' typeName typeParams constructors binds newAlgebraName vps genBody genVarbinds

-- | Where-binding-based algebra generator: like 'genAlgebraDec'', but each
-- constructor function also gets one @bindPat y = bindExp x@ declaration in
-- its where clause for every recursive position (x, y).
genAlgebraDec' :: forall t.
                     Name
                        -> [TyVarBndr]
                        -> [(Name, [(t, Type)])]
                        -> [DecQ]
                        -> Name
                        -> [PatQ]
                        -> (ExpQ -> ExpQ)
                        -> (Name -> PatQ)
                        -> (Name -> Name -> [(Maybe Type, Type)] -> ExpQ)
                        -> DecQ
genAlgebraDec' typeName typeParams constructors binds newAlgebraName vps bindExp bindPat genBody =
  genAlgebraDec'' typeName typeParams constructors binds newAlgebraName vps genBody varbindsFor
  where
    varbindsFor _ _ pbs = [mkBind pb | pb@(Just _, _) <- pbs]
    mkBind (Just (VarT y), VarT x) = valD (bindPat y) (normalB (bindExp (varE x))) []

-- | Shared core of the algebra generators.
--
-- Produces @newAlgebraName vps = Algebra {..}@. Its where clause defines one
-- function per constructor (named by 'funcName') followed by @binds@. Each
-- function takes one argument per constructor field and has body
-- @genBody n n' pbs@ and where-bindings @genVarbinds n n' pbs@.
--
-- In @pbs@, every field at index i is represented by the fresh variable
-- @VarT rv<i>@. A recursive field additionally carries @Just (VarT rvi<i>)@;
-- a non-recursive one carries @Nothing@. Fresh names are used for all fields,
-- so constructors with several fields of the same type variable (e.g.
-- @Leaf a a@) and fields of non-variable types do not produce conflicting or
-- invalid patterns.
genAlgebraDec'' :: forall t.
                      Name
                          -> [TyVarBndr]
                          -> [(Name, [(t, Type)])]
                          -> [DecQ]
                          -> Name
                          -> [PatQ]
                          -> (Name -> Name -> [(Maybe Type, Type)] -> ExpQ)
                          -> (Name -> Name -> [(Maybe Type, Type)] -> [DecQ])
                          -> DecQ
genAlgebraDec'' typeName typeParams constructors binds newAlgebraName vps genBody genVarbinds =
    let fieldEs = genWildcardFieldExp (map (\(n, _) -> funcName n) constructors)
        e = recConE (algebraName typeName) fieldEs
        freeType = genFreeType typeName typeParams
        decls = map genFunDecl constructors ++ binds
        genFunDecl (n, ps) =
          let n' = funcName n
              ts = map (\(_, t) -> t) ps
              pbs = zipWith3 mkpb ts (newVars "rv") (newVars "rvi")
              -- every position is named rv<i>; recursive ones also get rvi<i>
              mkpb t v vv = if t == freeType then (Just vv, v) else (Nothing, v)
              pats = map (\(_, VarT a) -> varP a) pbs
              b = genBody n n' pbs
              varbinds = genVarbinds n n' pbs
          in funD n' [clause pats (normalB b) varbinds]
    in funD newAlgebraName [clause vps (normalB e) decls]

-- | Substitute every element equal to the first argument with the second.
replace :: forall b. Eq b => b -> b -> [b] -> [b]
replace a b x = [if c == a then b else c | c <- x]

-- | Chain non-empty list of types into a right-nested function type
-- @t1 -> t2 -> ... -> tn@.
arrowConcat :: [TypeQ] -> TypeQ
arrowConcat = foldr1 (appT . appT arrowT)

-- | Field/function name for a constructor: its base name with the first
-- letter lower-cased (Node -> node).
funcName :: Name -> Name
funcName n = mkName (unCapalize (nameBase n))

-- | Lower-case the first character of a string; the empty string is
-- returned unchanged (the original definition was partial there).
unCapalize :: [Char] -> [Char]
unCapalize [] = []
unCapalize (x:y) = toLower x : y

-- | Name of the algebra record for a data type: BTree -> BTreeAlgebra.
algebraName :: Name -> Name
algebraName = mkName . (++ "Algebra") . nameBase

-- | Name of the map-functions record for a data type: BTree -> BTreeMapFs.
mapFunctionsName :: Name -> Name
mapFunctionsName = mkName . (++ "MapFs") . nameBase

-- | Field name in the map-functions record: node -> nodeF.
mfFuncName :: Name -> Name
mfFuncName n = mkName (nameBase n ++ "F")

-- | Declaration @CommutativeMonoid {oplus = oplus, identity = identity} = e@,
-- bringing the monoid operations of @e@ into scope.
monoidBind :: ExpQ -> DecQ
monoidBind e = recBind monoidCon (map mkName ["oplus", "identity"]) e id
  where monoidCon = mkName "CommutativeMonoid"

-- | Declaration @Con {f1 = f f1, ...} = e@ binding record fields of @e@
-- to (possibly renamed) variables.
recBind :: Name -> [Name] -> ExpQ -> (Name -> Name) -> DecQ
recBind n fs e f = valD lhs rhs []
  where
    lhs = recPat n fs f
    rhs = normalB e

-- | Record pattern @Con {f1 = f f1, ...}@ for the given field names.
recPat :: Name -> [Name] -> (Name -> Name) -> PatQ
recPat n fs f = recP n fieldPats
  where fieldPats = genWildcardFieldPat f fs

-- | The data type applied to its own parameters, e.g. @BTree a@. This is
-- the type that marks recursive positions in constructor fields.
--
-- Accepts both plain and kind-annotated binders; the original only matched
-- 'PlainTV' and failed on 'KindedTV'.
genFreeType :: Name -> [TyVarBndr] -> Type
genFreeType typeName typeParams = foldl AppT (ConT typeName) (map (VarT . binderName) typeParams)
  where
    binderName (PlainTV a) = a
    binderName (KindedTV a _) = a

-- | Record-construction fields @f = f@ for each name; an explicit spelling
-- of what RecordWildCards' @{..}@ would pick up from scope.
genWildcardFieldExp :: [Name] -> [Q (Name, Exp)]
genWildcardFieldExp ns = [fieldExp n (varE n) | n <- ns]

-- | Record-pattern fields @n = f n@ for each name, binding each field to a
-- variable named by @f@.
genWildcardFieldPat :: (Name -> Name) -> [Name] -> [FieldPatQ]
genWildcardFieldPat f ns = [fieldPat n (varP (f n)) | n <- ns]

-- | Infinite supply of type variables @s0, s1, s2, ...@ (used as fresh
-- variable names by the generators).
newVars :: [Char] -> [Type]
newVars s = [VarT (mkName (s ++ show i)) | i <- [0 :: Integer ..]]