module Feldspar.Core.Reify
  ( Program (..)
  , showCore
  , showCoreWithSize
  , printCore
  , printCoreWithSize
  , runGraph
  , buildSubFun
  , startInfo
  ) where
import Control.Monad.State
import Control.Monad.Writer
import Data.Char (digitToInt)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
import Data.Unique

import Feldspar.Core.Types
import Feldspar.Core.Ref
import Feldspar.Core.Expr
import Feldspar.Core.Graph hiding (function, Function (..), Variable)
import qualified Feldspar.Core.Graph as Graph
import Feldspar.Core.Show
-- | State threaded through reification.
data Info = Info
    { index   :: NodeId
      -- ^ Id that the next call to 'newIndex' hands out.
    , visited :: Map Unique NodeId
      -- ^ Node already emitted for each value, keyed by its 'dataId'.
    }
-- | Reification monad: emits graph 'Node's while threading an 'Info' state.
type Reify = WriterT [Node] (State Info)
-- | Initial reification state: ids start at 0 and nothing is visited yet.
startInfo :: Info
startInfo = Info { index = 0, visited = Map.empty }
-- | Run a reification computation from the given state.
-- Returns the computation's result, the emitted nodes, and the final state.
runGraph :: Reify a -> Info -> (a, ([Node], Info))
runGraph graph info =
    let ((result, emitted), final) = runState (runWriterT graph) info
    in  (result, (emitted, final))
-- | Hand out the current node id and advance the counter to its successor.
newIndex :: Reify NodeId
newIndex = do
    current <- gets index
    modify $ \info -> info {index = succ current}
    return current
-- | Record that node @i@ has been emitted for the value @a@.
remember :: Data a -> NodeId -> Reify ()
remember a i = modify addEntry
  where
    addEntry info = info {visited = Map.insert (dataId a) i (visited info)}
-- | Look up the node previously emitted for @a@, if there is one.
checkNode :: Data a -> Reify (Maybe NodeId)
checkNode a = do
    seen <- gets visited
    return $ Map.lookup (dataId a) seen
-- | Allocate a fresh id for @a@ and record it in 'visited'.
-- Then emit a node with the given function, input sources and input
-- types, tagged with @dataType a@.
node ::
    Data a -> Graph.Function -> Tuple Source -> Tuple StorableType -> Reify ()
node a@(Data _ _) fun inTup inType = do
    nid <- newIndex
    remember a nid
    let emitted = Node nid fun inTup inType (dataType a)
    tell [emitted]
-- | Emit a node for @a@ that has no inputs (e.g. a graph input or an array).
sourceNode :: Data a -> Graph.Function -> Reify ()
sourceNode a fun = node a fun noArgs noArgs
  where
    noArgs :: Tuple t
    noArgs = Tup []
-- | True when @a@ is a single (non-tuple) value whose storable type has an
-- empty dimension list, i.e. a scalar.
isPrimitive :: Data a -> Bool
isPrimitive a@(Data _ _) =
    case dataType a of
      One (StorableType dims _) -> null dims
      _                         -> False
-- | Resolve a value to the graph 'Source' that supplies it.
--
-- * A @getTup@ projection is peeled off. Its 1-based element index is
--   pushed, converted to 0-based, onto @path@, and the projected tuple is
--   resolved in its place.
-- * A primitive (scalar) constant is inlined as a 'Constant'.
-- * Any other value must already have a node, recorded by 'buildGraph'.
--   It becomes a 'Graph.Variable' naming that node together with the
--   accumulated projection @path@.
source :: [Int] -> Data a -> Reify Source
source path a = case dataToExpr a of
    -- The projection index is encoded in the function name. The second
    -- character after "getTup" is the 1-based element index (presumably
    -- the first one is the tuple arity).
    Application (Function ('g':'e':'t':'T':'u':'p':_:n:_) _) tup ->
      source ((digitToInt n - 1) : path) tup
    Value b | isPrimitive a ->
      let PrimitiveData b' = storableData b
       in return $ Constant b'
    _ -> do
      -- Use an explicit case rather than a failing 'Just' pattern: the
      -- reification monad has no meaningful 'fail', and a missing node is
      -- an internal invariant violation.
      known <- checkNode a
      case known of
        Just i  -> return $ Graph.Variable (i, path)
        Nothing -> error $ "Feldspar.Core.Reify.source: no node recorded "
                        ++ "for this value (buildGraph must run first)"
-- | Describe where the components of a (possibly nested) tuple come from.
--
-- Applications of @tup2@, @tup3@ and @tup4@ are flattened, recursively,
-- into a 'Tup' of their traced components. Any other value is a single
-- 'source' with an empty projection path.
traceTuple :: Data a -> Reify (Tuple Source)
traceTuple a = case dataToExpr a of
    Application (Application (Function "tup2" _) x1) x2 ->
      liftM Tup $ sequence [traceTuple x1, traceTuple x2]
    Application (Application (Application (Function "tup3" _) x1) x2) x3 ->
      liftM Tup $ sequence [traceTuple x1, traceTuple x2, traceTuple x3]
    Application (Application (Application (Application
                               (Function "tup4" _) x1) x2) x3) x4 ->
      liftM Tup $ sequence
        [traceTuple x1, traceTuple x2, traceTuple x3, traceTuple x4]
    _ -> do
      single <- source [] a
      return (One single)
-- | Emit graph nodes for the expression behind @a@. Inputs and
-- sub-functions are built before the node that consumes them. Nothing is
-- done if @a@ already has a node recorded in 'visited', so a value that
-- produced a node is never emitted twice.
--
-- Tuple constructors, @getTup@ projections and primitive constants get no
-- node of their own. 'traceTuple' and 'source' resolve them when a
-- consuming node's inputs are traced.
buildGraph :: forall a . Data a -> Reify ()
buildGraph a@(Data _ _) = do
    ia <- checkNode a
    unless (isJust ia) $ list (dataToExpr a)
  where
    -- Build the input first, then emit a node for @a@ that applies @fun@.
    -- The node's inputs are the traced sources of @inp@, and its input
    -- types are @dataType inp@.
    funcNode fun inp = do
      buildGraph inp
      inTup <- traceTuple inp
      node a fun inTup (dataType inp)
    -- Dispatch on the expression form of @a@.
    list :: Expr a -> Reify ()
    -- A free variable becomes a graph input.
    list Variable = sourceNode a Graph.Input
    -- Primitive constants are inlined later by 'source'; other (non-scalar)
    -- values become array source nodes.
    list (Value b)
      | isPrimitive a = return ()
      | otherwise     = sourceNode a $ Graph.Array $ storableData b
    -- Tuple construction emits no node: only the components are built, and
    -- 'traceTuple' reassembles them.
    -- NOTE(review): a two-argument application of any function other than
    -- "tup2" falls through to the equations below, none of which match a
    -- nested Application. Presumably only tuple constructors are applied
    -- to two arguments. Confirm.
    list (Application (Application (Function fun _) b) c)
      | fun == "tup2" = buildGraph b >> buildGraph c
    list (Application (Application (Application (Function "tup3" _) b) c) d) =
      buildGraph b >> buildGraph c >> buildGraph d
    list (Application (Application (Application (Application
                                               (Function "tup4" _) b) c) d) e) =
      buildGraph b >> buildGraph c >> buildGraph d >> buildGraph e
    -- A getTup projection emits no node ('source' follows the projection
    -- path). Any other unary function application becomes a function node.
    list (Application (Function fun _) b)
      | take 6 fun == "getTup" = buildGraph b
      | otherwise              = funcNode (Graph.Function fun) b
      
    -- The referenced function is built as a separate sub-function
    -- interface and applied to @b@.
    list (NoInline fun f b@(Data _ _)) = do
      iface <- buildSubFun (deref f)
      funcNode (Graph.NoInline fun iface) b
      
    -- Both branches become sub-functions. The node's input is the pair of
    -- the condition @c@ and the argument @b@.
    list (IfThenElse c t e b@(Data _ _)) = do
      ifaceThen <- buildSubFun t
      ifaceElse <- buildSubFun e
      funcNode (Graph.IfThenElse ifaceThen ifaceElse) (tup2 c b)
    -- The continuation test and the body become sub-functions; @b@ is the
    -- node's input.
    list (While cont body b@(Data _ _)) = do
      ifaceCont <- buildSubFun cont
      ifaceBody <- buildSubFun body
      funcNode (Graph.While ifaceCont ifaceBody) b
    -- The index function becomes a sub-function; @l@ is the node's input.
    list (Parallel l ixf) = do
      iface <- buildSubFun ixf
      funcNode (Graph.Parallel iface) l
-- | Build the graph for a lambda's input and output, and return its
-- 'Interface'. The interface holds the node id of the input, the traced
-- output tuple, and the input and output types derived from the data sizes.
buildSubFun :: forall a b . (Typeable a, Typeable b) =>
    (a :-> b) -> Reify Interface
buildSubFun (Lambda _ inp outp) = do
    sequence_ [buildGraph inp, buildGraph outp]
    outTup <- traceTuple outp
    seen   <- gets visited
    -- The input was emitted by 'buildGraph' above, so it is in 'visited'.
    let inId     = seen Map.! dataId inp
        ifaceIn  = typeOf (dataSize inp) (T::T a)
        ifaceOut = typeOf (dataSize outp) (T::T b)
    return (Interface inId outTup ifaceIn ifaceOut)
-- | Reify a function on 'Data' into a 'Graph'. The function is wrapped with
-- @lambda universal@ and built as a sub-function from 'startInfo'. The
-- emitted nodes are then paired with the resulting interface.
reifyD :: (Typeable a, Typeable b) => (Data a -> Data b) -> Graph
reifyD f =
    let (iface, (nodes, _)) = runGraph (buildSubFun (lambda universal f)) startInfo
    in  Graph nodes iface
-- | Programs that can be reified into a core 'Graph'.
class Program a where
    -- | Translate the program into its graph representation.
    reify :: a -> Graph

    -- | Number of arguments the program takes (0 for a plain value). Only
    -- the type, carried by the 'T' proxy, is consulted.
    numArgs :: T a -> Int
-- Computable values, and tuples of up to four of them, take no arguments.
-- They are reified through 'reify_computable'.
instance Computable a => Program a where
    reify   = reify_computable
    numArgs = const 0

instance (Computable a, Computable b) => Program (a,b) where
    reify   = reify_computable
    numArgs = const 0

instance (Computable a, Computable b, Computable c) => Program (a,b,c) where
    reify   = reify_computable
    numArgs = const 0

instance (Computable a, Computable b, Computable c, Computable d) => Program (a,b,c,d) where
    reify   = reify_computable
    numArgs = const 0
-- Functions of one to four computable arguments. Multi-argument functions
-- are first turned into functions of a single tuple argument, then lowered
-- with 'lowerFun' and reified with 'reifyD'.
instance (Computable a, Computable b) => Program (a -> b) where
    reify f   = reifyD (lowerFun f)
    numArgs _ = 1

instance (Computable a, Computable b, Computable c) => Program (a -> b -> c) where
    reify f   = reifyD (lowerFun (\(x, y) -> f x y))
    numArgs _ = 2

instance (Computable a, Computable b, Computable c, Computable d) => Program (a -> b -> c -> d) where
    reify f   = reifyD (lowerFun (\(x, y, z) -> f x y z))
    numArgs _ = 3

instance (Computable a, Computable b, Computable c, Computable d, Computable e) => Program (a -> b -> c -> d -> e) where
    reify f   = reifyD (lowerFun (\(x, y, z, w) -> f x y z w))
    numArgs _ = 4
-- | Reify a computable value as a function that ignores its unit input
-- and returns the internalized value.
reify_computable :: forall a . Computable a => a -> Graph
reify_computable x = reifyD ignoreInput
  where
    ignoreInput :: Data () -> Data (Internal a)
    ignoreInput _ = internalize x
-- | Render the core graph of a program as text. The first 'showGraph' flag
-- is False here; 'showCoreWithSize' passes True.
showCore :: forall a . Program a => a -> String
showCore prog = showGraph False "program" hasArgs (reify prog)
  where
    hasArgs = numArgs (T::T a) > 0
-- | Like 'showCore', but passes True as the first 'showGraph' flag
-- (presumably enabling size annotations).
showCoreWithSize :: forall a . Program a => a -> String
showCoreWithSize prog = showGraph True "program" hasArgs (reify prog)
  where
    hasArgs = numArgs (T::T a) > 0
-- | Print the output of 'showCore' to standard output.
printCore :: Program a => a -> IO ()
printCore prog = putStrLn (showCore prog)
-- | Print the output of 'showCoreWithSize' to standard output.
printCoreWithSize :: Program a => a -> IO ()
printCoreWithSize prog = putStrLn (showCoreWithSize prog)
-- | A 'Data' value is shown by rendering its core graph with 'showCore'.
instance Storable a => Show (Data a) where
  show d = showCore d