{-# OPTIONS_GHC -Wall #-}
{-# Language TypeOperators #-}
{-# Language TypeFamilies #-}
{-# Language FlexibleInstances #-}

module Dvda.FunGraph ( FunGraph
                     , ToFunGraph
                     , NumT
                     , (:*)(..)
                     , MVS(..)
                     , toFunGraph
                     , countNodes
                     , fgInputs
                     , fgOutputs
                     , fgLookupGExpr
                     , fgReified
                     , topSort
--                     , fgGraph
                     , nodelistToFunGraph
                     , exprsToFunGraph
                     ) where

import Control.Applicative
import Data.Foldable ( Foldable )
import qualified Data.Foldable as F
import qualified Data.Graph as Graph
import Data.Hashable ( Hashable )
import qualified Data.HashSet as HS
import Data.Traversable ( Traversable )
import qualified Data.Traversable as T

import Dvda.Expr
import Dvda.Reify ( ReifyGraph(..), reifyGraphs )

data FunGraph a = FunGraph { fgGraph :: Graph.Graph
                           , fgInputs :: [MVS (GExpr a Int)]
                           , fgOutputs :: [MVS Int]
                           , fgReified :: [(Int, GExpr a Int)]
                           , fgLookupGExpr :: Int -> Maybe (GExpr a Int)
                           , fgVertexFromKey :: Int -> Maybe Int
                           , fgNodeFromVertex :: Int -> (GExpr a Int, Int, [Int])
                           }

instance Show a => Show (FunGraph a) where
  show fg = "FunGraph\ninputs:\n" ++ show (fgInputs fg) ++ "\noutputs:\n" ++ show (fgOutputs fg) ++ "\ngraph:\n" ++ show (fgGraph fg)

---- | matrix or vector or scalar
data MVS a = Mat [[a]] | Vec [a] | Sca a deriving Show

instance Functor MVS where
  fmap f (Sca x)  = Sca (f x)
  fmap f (Vec xs) = Vec (map f xs)
  fmap f (Mat xs) = Mat (map (map f) xs)

instance Foldable MVS where
  foldr f x0 (Sca x)  = foldr f x0 [x]
  foldr f x0 (Vec xs) = foldr f x0 xs
  foldr f x0 (Mat xs) = foldr f x0 (concat xs)

instance Traversable MVS where
  traverse f (Sca x)  = Sca <$> f x
  traverse f (Vec xs) = Vec <$> T.traverse f xs
  traverse f (Mat xs) = Mat <$> T.traverse (T.traverse f) xs

class ToFunGraph a where
  type NumT a
  toMVSList :: a -> [MVS (Expr (NumT a))]
instance ToFunGraph (Expr a) where
  type NumT (Expr a) = a
  toMVSList x = [Sca x]
instance ToFunGraph [Expr a] where
  type NumT [Expr a] = NumT (Expr a)
  toMVSList x = [Vec x]
instance ToFunGraph [[Expr a]] where
  type NumT [[Expr a]] = NumT [Expr a]
  toMVSList x = [Mat x]

data a :* b = a :* b deriving Show
infixr 6 :*
instance (ToFunGraph a, ToFunGraph b, NumT a ~ NumT b) => ToFunGraph (a :* b) where
  type NumT (a :* b) = NumT a
  toMVSList (x :* y) = toMVSList x ++ toMVSList y

-- | find any symbols which are parents of outputs, but are not supplied by the user
detectMissingInputs :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [(Int,GExpr a Int)] -> [GExpr a Int]
detectMissingInputs exprs gr = HS.toList $ HS.difference allGraphInputs allUserInputs
  where
    allUserInputs = let f (ESym name) acc = (GSym name):acc
                        f _ e = error $ "detectMissingInputs given non-ESym input \"" ++ show e ++ "\""
                    in HS.fromList $ foldr f [] (concatMap F.toList exprs)

    allGraphInputs = let f (_,(GSym name)) acc = (GSym name):acc
                         f _ acc = acc
                     in HS.fromList $ foldr f [] gr

-- | if the same input symbol (like ESym "x") is given at two different places throw an exception
findConflictingInputs :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [Expr a]
findConflictingInputs exprs = HS.toList redundant
  where
    redundant = snd $ foldl f (HS.empty, HS.empty) (concatMap F.toList exprs)
      where
        f (knownExprs, redundantExprs) expr@(ESym _)
          | HS.member expr knownExprs = (knownExprs, HS.insert expr redundantExprs)
          | otherwise = (HS.insert expr knownExprs, redundantExprs)
        f _ e = error $ "findConflictingInputs saw non-ESym input \"" ++ show e ++ "\""


-- | Take inputs and outputs which are of classes ToFunGraph (heterogenous lists of @Expr a@)
--   and traverse the outputs reifying all expressions and creating a hashmap of StableNames (stable pointers).
--   Once the hashmap is created, lookup the provided inputs and return a FunGraph which contains an
--   expression graph, input/output indices, and other useful functions. StableNames is non-deterministic
--   so this function may return graphs with more or fewer CSE's eliminated.
--   If CSE is then performed on the graph, the result is deterministic.
toFunGraph :: (Eq a, Hashable a, Show a, ToFunGraph b, ToFunGraph c, NumT b ~ a, NumT c ~ a)
              => b -> c -> IO (FunGraph a)
toFunGraph inputs outputs = mvsToFunGraph (toMVSList inputs) (toMVSList outputs)

mvsToFunGraph :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [MVS (Expr a)] -> IO (FunGraph a)
mvsToFunGraph inputMVSExprs outputMVSExprs = do
  -- reify the outputs
  (ReifyGraph rgr, outputMVSIndices) <- reifyGraphs outputMVSExprs
  let fg = nodelistToFunGraph rgr inputMVSGExprs outputMVSIndices
      inputMVSGExprs = map (fmap f) inputMVSExprs
        where
          f (ESym name) = (GSym name)
          f x = error $ "ERROR: mvsToFunGraph given non-ESym input \"" ++ show x ++ "\""
  return $ case (detectMissingInputs inputMVSExprs rgr, findConflictingInputs inputMVSExprs) of
    ([],[]) -> fg
    (xs,[]) -> error $ "mvsToFunGraph found inputs that were not provided by the user: " ++ show xs
    ( _,xs) -> error $ "mvsToFunGraph found idential inputs set more than once: " ++ show xs

nodelistToFunGraph :: [(Int,GExpr a Int)] -> [MVS (GExpr a Int)] -> [MVS Int] -> FunGraph a
nodelistToFunGraph rgr inputMVSIndices outputMVSIndices =
  FunGraph { fgGraph = gr
           , fgInputs = inputMVSIndices
           , fgOutputs = outputMVSIndices
           , fgLookupGExpr = lookupG
           , fgReified = rgr
           , fgVertexFromKey = lookupKey
           , fgNodeFromVertex = lookupVertex
           }
  where
    -- make sure all the inputs are symbolic, and find their indices in the Expr graph
    (gr, lookupVertex, lookupKey) = Graph.graphFromEdges $ map (\(k,gexpr) -> (gexpr, k, getParents gexpr)) rgr
    lookupG k = (\(g,_,_) -> g) <$> lookupVertex <$> lookupKey k


---------------------------------- utilities -----------------------------
countNodes :: FunGraph a -> Int
countNodes = length . Graph.vertices . fgGraph

topSort :: FunGraph a -> [Int]
topSort fg = map ((\(_,k,_) -> k) . (fgNodeFromVertex fg)) $ Graph.topSort (fgGraph fg)

-- | make a FunGraph out of outputs, automatically detecting the proper inputs
exprsToFunGraph :: (Eq a, Show a, Hashable a) => [Expr a] -> IO (FunGraph a)
exprsToFunGraph outputs = do
  let getSyms :: [Expr a] -> [Sym]
      getSyms exprs = HS.toList $ foldr (\acc expr -> foldExpr f expr acc) HS.empty exprs
        where
          f (ESym s) hs = HS.insert s hs
          f _ hs = hs
      inputs = map ESym $ getSyms outputs
  toFunGraph inputs outputs