{-# LANGUAGE TemplateHaskell #-}

{-|
    Module used internally by the "Derive" module for binding group calculation.
 -}

module Text.GRead.Derive.BindingGroup (
    -- * Binding group
      BindingGroup
    , getBindingGroup
    , showBindingGroup
    -- * Template Haskell helpers
    , UserType(..)
    , getUserType
    , unrollApp
    ) where

import Language.Haskell.TH.Syntax

import           Data.List (nub, intersect)
import           Data.Maybe (fromJust)
import qualified Data.Map as M

import qualified Data.Graph.Inductive.Graph as G
import qualified Data.Graph.Inductive.Tree as G
import qualified Data.Graph.Inductive.Query.BFS as G
import qualified Data.Graph.Inductive.Query.DFS as G
import qualified Data.Graph.Inductive.NodeMap as G

import Control.Monad (foldM)

-- | Uniform representation of 'data' and 'newtype' declarations: the type's
-- name, the names of its type parameters (as reified), and its constructors
-- (a @newtype@ contributes its single constructor).
data UserType  = UserD Name [Name] [Con]

-- | Edges can be Strong (a direct reference to a type) or Weak (reached by
-- following a type variable bound in the 'VarMap').  Both carry the
-- instantiation of the target type's parameters (see 'mkNextEdgeType').
data EdgeType  = Strong [Type] | Weak [Type] deriving (Show, Eq)

-- | Graph of our type dependencies: nodes are labelled with type names and
-- edges with the kind of dependency ('EdgeType').
type TypeGraph   = G.Gr      Name EdgeType
-- | One fgl context of a 'TypeGraph': incoming edges, node, label and
-- outgoing edges.
type TypeContext = G.Context Name EdgeType

-- | The kind of the next edge to create, paired with the node it starts
-- from (the parent of the next type node inserted by 'insNode').
type NextEdge  = (EdgeType, G.Node)

-- | Map for looking up vars: binds type-parameter names of the types being
-- visited to the types they are applied to (see 'extendVarMap').
type VarMap    = M.Map Name Type

-- | State inherited (passed down) while walking the tree of types, starting
-- at the queried type ('startState') and following constructor field types.
data InheritState = IS { nextEdge  :: NextEdge
                         -- ^ Edge to attach to the next node inserted by 'insNode'.
                       , varMap    :: VarMap
                         -- ^ Current bindings of type parameters.
                       , typeGraph :: TypeGraph
                         -- ^ Graph rebuilt from 'contexts' on every 'insNode'.
                       , bindings  :: M.Map Name VarMap
                         -- ^ For each type expanded so far, the variable map it
                         -- was (last) expanded with.  Also used as a loop-check
                         -- (see 'extendGraphCon').
                       -- The next two fields are really synthesized results; they
                       -- are kept in the inherited state so they can be threaded
                       -- from one visited type to the next (see 'extendGraphType').
                       , nodeMap   :: G.NodeMap Name
                         -- ^ Mapping from type names to graph nodes.
                       , contexts  :: [TypeContext] 
                         -- ^ Graph contexts inserted so far, most recent first.
                       }

-- | Result threaded back out of the walk: the node map and the graph
-- contexts collected so far (the synthesized part of 'InheritState').
type SynthesizedState = (G.NodeMap Name, [TypeContext])

-- | A binding group: for each type in the group, its name paired with the
-- other group members it refers to directly.  Each referenced member comes
-- with the list of distinct type-argument instantiations it is used at
-- inside the group (the '[[Type]]').
type BindingGroup = [(Name, [(Name, [[Type]])])]

-- | Initial walk state for the queried type: a graph holding only that
-- type's node, an empty variable map, and the type marked as expanded with
-- the empty variable map.  The first edge created starts at this node.
startState :: Name -> InheritState
startState firstType =
    IS { nextEdge  = (Strong [], rootNode)
       , varMap    = M.empty
       , bindings  = M.singleton firstType M.empty
       , typeGraph = rootContext G.& G.empty
       , nodeMap   = rootNodeMap
       , contexts  = [rootContext]
       }
  where
    ((rootNode, rootLabel), rootNodeMap) = G.mkNode G.new firstType
    rootContext                          = ([], rootNode, rootLabel, [])

-- | Project the synthesized part ('nodeMap' and 'contexts') out of a walk
-- state, as a 'Q' result.
returnState :: InheritState -> Q SynthesizedState
returnState = return . synthesized
  where synthesized st = (nodeMap st, contexts st)

-- | Debugging helper: a splice that evaluates to the binding group of
-- @name@, with each entry rendered by 'show'.
showBindingGroup :: Name -> Q Exp
showBindingGroup name = do
    rendered <- fmap (map show) (getBindingGroup name)
    [| rendered |]
{-
showTypeGraph :: Name -> Q Exp
showTypeGraph name = do
    tgraph <- getTypeGraph name
    let tgraph' = show tgraph
    [|tgraph'|]
-}
-- | Find cyclic type dependencies (binding groups).
--
-- The binding group of @name@ is the strongly connected component of the
-- type graph that contains @name@: every type reachable from @name@ that can
-- also reach @name@ again.  When @name@ is not on any cycle the result is
-- empty.  The queried type itself comes first.
--
-- Each entry pairs a member's name with the group members it references
-- through strong edges, together with their instantiations
-- (see 'outgoingBindings').
getBindingGroup :: Name -> Q BindingGroup
getBindingGroup name = do
    typegr <- getTypeGraph name
    -- The queried type is the first node 'startState' creates, i.e. node 0.
    -- It lies on a cycle exactly when some node (possibly itself) points back
    -- to it.  Collect every node on *some* cycle through the root.  Taking
    -- only the shortest path back from each predecessor of the root would
    -- miss members that lie solely on longer cycles.
    let root         = 0
        onCycle      = not (null (G.pre typegr root))
        bindingGroup
            | onCycle   = G.reachable root typegr `intersect` G.rdfs [root] typegr
            | otherwise = []
        -- Create the bindings first; they are looked up when constructing the final binding group.
        bindings     = typeBindings typegr bindingGroup
    return $ map (\x -> (typeName typegr x, outgoingBindings typegr x bindings)) bindingGroup
    -- TODO Do we need to take into account the case that the 'top' type takes arguments and add a variable binding?

-- | The type name labelling a node of the type graph.
--
-- Every node is inserted together with its label, so a missing label means
-- the node does not belong to the graph.  That is an internal error, and it
-- is reported with the offending node instead of a bare 'fromJust' failure.
typeName :: TypeGraph -> G.Node -> Name
typeName typegr node =
    maybe (error $ "typeName: node " ++ show node ++ " has no label in the type graph")
          id
          (G.lab typegr node)

-- | For each node of the binding group, the distinct instantiations carried
-- by its incoming edges (strong and weak alike).
typeBindings :: TypeGraph -> [G.Node] -> M.Map G.Node [[Type]]
typeBindings typegr bindingGroup =
    M.fromList [ (node, incomingInstantiations node) | node <- bindingGroup ]
  where
    incomingInstantiations node =
        nub [ getBindings edge | (_, _, edge) <- G.inn typegr node ]

-- | The binding-group members that @node@ references through strong
-- (direct) edges, excluding @node@ itself, each with its instantiations
-- from @bindings@.  The result follows the key order of @bindings@.
-- Only strong edges are considered.
outgoingBindings :: TypeGraph -> G.Node -> M.Map G.Node [[Type]] -> [(Name, [[Type]])]
outgoingBindings typegr node bindings =
    [ (typeName typegr target, insts)
    | (target, insts) <- M.toList bindings
    , target /= node
    , target `elem` strongTargets
    ]
  where
    strongTargets = [ t | (_, t, Strong _) <- G.out typegr node ]

-- | The instantiation carried by an edge, whatever its strength.
-- TODO This suggests a better (wrapping) type
getBindings :: EdgeType -> [Type]
getBindings edge = case edge of
    Strong types -> types
    Weak   types -> types

-- | Build the type graph of @name@ by walking its constructors' field types,
-- starting from 'startState' for the reified type name.
getTypeGraph :: Name -> Q TypeGraph
getTypeGraph name =
    getUserType name >>= \utype@(UserD root _ _) ->
        fmap (mkGraph' . snd) (extendGraphType (startState root) utype)

-- | Assemble a graph from a list of contexts.  Duplicate labelled nodes are
-- removed.  Edges are taken from both the incoming and outgoing side of
-- every context and are deliberately not de-duplicated here; callers nub
-- where it matters.
mkGraph' :: (Eq a, Eq b) => [G.Context a b] -> G.Gr a b
mkGraph' ctxs = G.mkGraph labelled links
  where
    labelled = nub (map G.labNode' ctxs)
    links    = concat [ G.inn' c ++ G.out' c | c <- ctxs ]

-- | Get a unified type for 'data' and 'newtype'.
--
-- Reifies @name@ and returns its name, its type-parameter names and its
-- constructors.  Fails in the 'Q' monad when @name@ does not refer to a
-- @data@ or @newtype@ declaration (e.g. a type synonym, class or value).
getUserType :: Name -> Q UserType
getUserType name = do
    info <- reify name
    case info of
        TyConI (DataD    _ uname args cs _) -> return $ UserD uname (map tyVarBndr2Name args) cs
        TyConI (NewtypeD _ uname args c  _) -> return $ UserD uname (map tyVarBndr2Name args) [c]
        _                                   -> scopeError
    where
      -- 'fail' reports through Q's own error mechanism (which 'recover'
      -- handles).  An 'error' thunk would instead surface as an exception
      -- thrown from compile-time code.
      scopeError = fail $ "Can only be used on algebraic datatypes (which " ++ (show name) ++ " isn't)"


-- | The name bound by a type-variable binder, with or without a kind.
tyVarBndr2Name :: TyVarBndr -> Name
tyVarBndr2Name bndr = case bndr of
    PlainTV  varName   -> varName
    KindedTV varName _ -> varName

-- | Visit every field type of every constructor of a user type.  The node
-- map and contexts are threaded from one field to the next; all other
-- fields of the inherited @state@ ('nextEdge', 'varMap', 'bindings', ...)
-- are the same for every field.
extendGraphType :: InheritState -> UserType -> Q SynthesizedState
extendGraphType state (UserD _ _ cons) =
    returnState state >>= \initial -> foldM visitField initial (getCtx cons)
  where
    visitField (nm, ctxs) fieldType =
        extendGraph (state { contexts = ctxs, nodeMap = nm }) fieldType

-- | Field types of a list of constructors, in order: the "context" of a data
-- type, based on the arguments of its constructors.  Not the context of a
-- graph.
--
-- An existentially quantified constructor ('ForallC') contributes the
-- fields of the constructor it wraps.  Its quantified variables are
-- presumably unbound in the 'VarMap', in which case 'extendGraphVar' stops
-- at them.
getCtx :: [Con] -> [Type]
getCtx = concatMap fieldTypes
  where
    fieldTypes (NormalC _ args)      = map snd args
    fieldTypes (RecC    _ args)      = map (\(_, _, t) -> t) args
    fieldTypes (InfixC  argl _ argr) = [snd argl, snd argr]
    fieldTypes (ForallC _ _ con)     = fieldTypes con
    fieldTypes other                 = error $ "getCtx: unsupported constructor form: " ++ show other

-- | Third component of a triple.
thd :: (a, b, c) -> c
thd triple = case triple of (_, _, z) -> z

-- | Extend the graph with one field type, dispatching on its shape.
-- Only type variables, type constructors and applications are handled.
-- Anything else (e.g. 'ListT', 'TupleT' or 'ArrowT', including as the head
-- of an application passed on by 'extendGraphApp') aborts with an error.
extendGraph :: InheritState -> Type -> Q SynthesizedState
extendGraph state (VarT varname)  = extendGraphVar state varname
extendGraph state (ConT conname)  = extendGraphCon state conname
extendGraph state app@(AppT _ _)  = extendGraphApp state app
extendGraph _     unsupported     = error $ "Couldn't match: " ++ show unsupported

-- | Visit a type variable.  An unbound variable (a parameter of the
-- top-level type) or one bound to another variable ends the walk here.
-- A variable bound to a concrete type is followed, but the edge to that
-- type is made weak.
extendGraphVar :: InheritState -> Name -> Q SynthesizedState
extendGraphVar state name = maybe stop follow (M.lookup name (varMap state))
  where
    stop                = returnState state
    follow (VarT _)     = stop
    follow boundType    = extendGraph weakened boundType
    weakened            = state { nextEdge = (Weak [], snd (nextEdge state)) }

-- | Visit a type constructor: insert its node (linked from the current
-- parent), then expand its constructors.  Expansion is skipped when the
-- type was already expanded and either that expansion used the same
-- variable map or the type has no parameters (loop-check).
extendGraphCon :: InheritState -> Name -> Q SynthesizedState
extendGraphCon state name = do
    linked                      <- insNode state name
    userType@(UserD _ params _) <- getUserType name
    let currentVars = varMap state
        alreadyDone = maybe False
                            (\seen -> seen == currentVars || null params)
                            (M.lookup name (bindings state))
    if alreadyDone
        then returnState linked
        else extendGraphType (linked { bindings = M.insert name currentVars (bindings state) }) userType

-- | Insert (or reuse) the node for type @name@ together with an edge from
-- the pending parent.  The edge kind comes from 'nextEdge'; its payload is
-- the type's parameters instantiated through the current 'varMap'.  The new
-- node becomes the parent for subsequent strong edges.
insNode :: InheritState -> Name -> Q InheritState
insNode state name = do
    UserD _ params _ <- getUserType name
    let (kind, parent)        = nextEdge state
        ((node, _), nodeMap') = G.mkNode (nodeMap state) name
        incoming              = (mkNextEdgeType kind params (varMap state), parent)
        contexts'             = ([incoming], node, name, []) : contexts state
    return state { nextEdge  = (Strong [], node)
                 , nodeMap   = nodeMap'
                 , contexts  = contexts'
                 , typeGraph = mkGraph' contexts'
                 }

-- | Build an edge of the same strength as @edge@.  Its payload is the
-- instantiation of @params@ through @vm@ (see 'mkNextEdgeType''); the old
-- payload of @edge@ is discarded.
mkNextEdgeType :: EdgeType -> [Name] -> VarMap -> EdgeType
mkNextEdgeType edge params vm = rewrap (mkNextEdgeType' params vm [])
  where
    rewrap = case edge of
        Strong _ -> Strong
        Weak   _ -> Weak

-- | Instantiate each parameter name through @vm@; unbound parameters stay as
-- type variables.  The instantiations are prepended to @acc@ in *reverse*
-- parameter order.
-- NOTE(review): the reversal looks accidental (accumulator style), but
-- consumers of 'BindingGroup' may depend on it — confirm before changing.
mkNextEdgeType' :: [Name] -> VarMap -> [Type] -> [Type]
mkNextEdgeType' params vm acc = reverse (map instantiate params) ++ acc
  where
    instantiate p = M.findWithDefault (VarT p) p vm

-- | Visit a type application such as @T a Int@.
--
-- The head and the arguments are substituted through the current variable
-- map.  The head's type parameters are then bound to the arguments before
-- the head is visited.  Substitution can turn the head itself into an
-- application: a higher-kinded parameter bound to a partially applied type,
-- e.g. @f := Either Bool@ in @f Int@.  That application is flattened first,
-- so the remaining arguments (@Int@) are not lost.
extendGraphApp :: InheritState -> Type -> Q SynthesizedState
extendGraphApp state app =
    case flatten (replaceVars (unrollApp app) (varMap state)) of
        app' : appargs -> do
            varmap' <- extendVarMap app' appargs (varMap state)
            extendGraph (state { varMap = varmap' }) app'
        -- 'unrollApp' of an application always has a head and an argument.
        []             -> error "extendGraphApp: empty type application"
  where
    -- Values in the variable map were substituted when they were bound, so
    -- the flattened head and its arguments are not substituted again.
    flatten (hd@(AppT _ _) : args) = unrollApp hd ++ args
    flatten types                  = types

-- | Bind the type parameters of a constructor head to the application's
-- arguments, on top of the existing map; the new bindings take precedence.
-- A head that is not a type constructor leaves the map unchanged.
extendVarMap :: Type -> [Type] -> VarMap -> Q VarMap
extendVarMap (ConT uname) appargs varmap = do
    UserD _ params _ <- getUserType uname
    return (M.fromList (zip params appargs) `M.union` varmap)
extendVarMap _ _ varmap = return varmap

-- | Get the types of a type application: the head followed by its
-- arguments in source order, e.g. @Either Int Bool@ becomes
-- @[Either, Int, Bool]@.  Any head (a constructor, variable, 'ListT',
-- 'ArrowT', ...) is accepted.  A type that is not an application is
-- returned as a singleton list.
unrollApp :: Type -> [Type]
unrollApp = go []
  where
    -- Walk down the left spine, collecting arguments from the right.
    go args (AppT fun arg) = go (arg : args) fun
    go args headType       = headType : args

-- | Substitute every bare type variable that is bound in @vm@; all other
-- types (including unbound variables) are kept unchanged.  Substitution is
-- shallow: only top-level 'VarT's of the list elements are replaced.
replaceVars :: [Type] -> VarMap -> [Type]
replaceVars types vm = map substitute types
  where
    substitute t@(VarT v) = M.findWithDefault t v vm
    substitute t          = t