{-# LANGUAGE ScopedTypeVariables #-} 
-----------------------------------------------------------------------------
-- |
-- Module      :  ForSyDe.Backend.Simulate
-- Copyright   :  (c) SAM Group, KTH/ICT/ECS 2007-2008
-- License     :  BSD-style (see the file LICENSE)
-- 
-- Maintainer  :  forsyde-dev@ict.kth.se
-- Stability   :  experimental
-- Portability :  non-portable (Template Haskell, LSTV)
--
-- This module provides the simulation backend of ForSyDe's embedded compiler
--
-- /This module is based on Lava2000/: <http://www.cs.chalmers.se/~koen/Lava/>
--
-----------------------------------------------------------------------------
module ForSyDe.Backend.Simulate (simulate) where

import ForSyDe.OSharing
import ForSyDe.Netlist
import ForSyDe.Netlist.Traverse
import ForSyDe.System.SysDef
import ForSyDe.System.SysFun(SysFunToSimFun(..))
import ForSyDe.ForSyDeErr
import ForSyDe.Process.ProcVal

import Data.Maybe (fromJust)
import Control.Monad.ST
import Data.STRef
import qualified Data.Traversable as DT
import Data.List (transpose)
import Data.Dynamic

-- | 'simulate' takes a system definition and generates a function 
--   able simulate a System using a list-based representation 
--   of its signals.
simulate :: SysFunToSimFun sysFun simFun => SysDef sysFun -> simFun
simulate sysDef = fromDynSimFun (simulateDyn (unSysDef sysDef)) []

-- FIXME: clean and document the following horrible code!

--------------------------------------
-- The following was adapted from Lava
--------------------------------------

type Var s
  = (STRef s Dynamic, STRef s (Wire s))

data Wire s
  = Wire
    { dependencies :: [Var s]
    , kick         :: ST s ()
    }

----------------------------------------------------------------
-- simulateDyn


simulateDyn :: PrimSysDef  -> [[Dynamic]] -> [[Dynamic]]
simulateDyn pSysDef inps | any null inps = replicate outN []
   where outN = (length . oIface . readURef . unPrimSysDef) pSysDef
simulateDyn pSysDef inps = runST (
  do let sysDefVal = (readURef . unPrimSysDef) pSysDef
         sysDefInIface = iIface sysDefVal
     -- List where to store the Vars generated by delay processes
     roots <- newSTRef []
     -- Input port ids paired with a reference to each input value list
     inpPairs <- zipWithM (\(id,_) inputL -> 
                            do {ref <- newSTRef inputL; return (id,ref)}) 
                         sysDefInIface inps

     let -- Add a Var to the roots list 
         root r =
           do rs <- readSTRef roots
              writeSTRef roots (r:rs)

         -- Create an empty var
         empty = do rval <- newSTRef (error "val?")
                    rwir <- newSTRef (error "wire?")
                    return (rval, rwir)
         new node = 
           do mapM (\tag -> do {e <- empty; return (tag, e)}) (outTags node)
         newInstance varPairs node =
           let funName = "ForSyDe.Backend.Simulate.simulateDyn" 
           in case node of
             InPort id ->
                 case lookup id varPairs of
                     -- FIXME: replace the Other error with a custom one
                     Nothing  -> intError funName (Other "inconsistency")
                     Just var -> return [(InPortOut, var)]
             _      -> new node
         
         -- define for the general traversal
         define  nodeVarPairs childVars = 
           case (nodeVarPairs,childVars) of
            ([(InPortOut, var)], InPort name) -> do
              let inputRef = fromJust $ lookup name inpPairs
              relate var [] $
                do (curr:rest) <- readSTRef inputRef
                   writeSTRef inputRef rest
                   return curr
                   
            _ -> defineShared nodeVarPairs childVars

         -- define for instances
         defineInstance nodeVarPairs childVars = 
           case (nodeVarPairs,childVars) of
            ([(InPortOut, _)], InPort _) -> return ()
            _ -> defineShared nodeVarPairs childVars
         
         -- Shared part of define define for instances and the main traversal
         defineShared  nodeVarPairs childVars = -- r s =
           case (nodeVarPairs,childVars) of
            ([(InPortOut, _)], InPort _) -> return ()

            (nodeVarPairs,
             Proc _ (SysIns pSysDef ins)) ->
               -- FIXME: ugly ugly ugly
               do let sysDefVal = (readURef . unPrimSysDef) pSysDef
                      taggedIns = zipWith (\(id,_) var -> (id,var)) 
                                          (iIface sysDefVal) ins
                  sr  <- traverseST 
                           (newInstance taggedIns)
                           defineInstance 
                           (netlist sysDefVal)

                  let relateIns prevVar@(prevValR,_) (_,nextVar) =
                         relate nextVar [prevVar] (readSTRef prevValR)
 
                  zipWithM_ relateIns sr nodeVarPairs

            ([(DelaySYOut, nodeVar)], 
             Proc _ (DelaySY (ProcVal init _) sigVar)) ->
               do valVar <- empty
                  relate valVar [] (return init)                    
                  delay nodeVar valVar sigVar
            _ ->
              do let evalPairs = eval `fmap` DT.mapM (readSTRef.fst) childVars
                     args = arguments childVars
                     relEval (tag, var) = 
                         relate var args $
                            -- FIXME: remove fromJust and write a proper error
                            liftM (fromJust.(lookup tag)) evalPairs
                 mapM_ relEval  nodeVarPairs  
          where
           delay r ri@(rinit,_) r1@(pre,_) =
               do state <- newSTRef Nothing
                  r2 <- empty
                  root r2

                  relate r [ri] $
                    do ms <- readSTRef state
                       case ms of
                         Just s  -> return s
                         Nothing ->
                           do s <- readSTRef rinit
                              writeSTRef state (Just s)
                              return s

                  relate r2 [r,r1] $
                    do s <- readSTRef pre
                       writeSTRef state (Just s)
                       return s
     
     sr   <- traverseST new define (netlist sysDefVal)
     rs   <- readSTRef roots
     -- remove tags of the resulting vars (all the root nodes should only
     -- have one output and thus a must return a unique list)
     step <- drive (sr ++ rs)

     outs <- lazyloop $
       do step
          s <- DT.mapM (readSTRef . fst) sr
          return s
     -- Since the simulation is done in a per-cycle basis
     -- the results (outs) are transposed (not what we want)
     -- e.g.
     -- imagine a system whose outputs are its inputs plus 1
     -- then, for this these two inputs [[1,2,3],[4,5,6]]
     -- outs would be [[2,5],[3,6],[4,7]]
     --
     -- We need as well to check that all inputs are defined in
     -- each simulation cycle e.g. [[1,2,3],[4,5,6,7]] as input
     -- makes imposible to simulate cycle 4
     --
     -- NOTE: having to check this makes simulation really inneficient
     -- a solution would providing a cycle-based simulation input/output
     -- which wouldn't suffer from this problems
     -- Or, even better, a simulation which showed a diffierent type of output 
     let inN = length sysDefInIface
         res = if inN == 0 then
                 transpose outs
               else transpose (checkIns inN (transpose inps) outs)
     return res
  )

   
-- evaluation order

relate :: Var s -> [Var s] -> ST s Dynamic -> ST s ()
relate (rval, rwir) rs f =
  do writeSTRef rwir $ 
       Wire{ dependencies = rs
           , kick = do b <- f
                       writeSTRef rval b
           }

drive :: [Var s] -> ST s (ST s ())
drive [] =
  do return (return ())

drive ((_,rwir):rs) =
  do wire <- readSTRef rwir
     writeSTRef rwir (error "detected combinational loop")
     driv1 <- drive (dependencies wire)
     writeSTRef rwir $
       Wire { dependencies = [], kick = return () }
     driv2 <- drive rs
     return $
       do driv1
          kick wire
          driv2

----------------------------------------------------------------
-- helper functions

lazyloop :: ST s a -> ST s [a]
lazyloop m = 
  do a  <- m
     as <- unsafeInterleaveST (lazyloop m)
     return (a:as)

-- | check that there will only be output as long as there are inputs 
checkIns :: Int -- ^ number of inputs 
         -> [[a]] -- ^ transposed inputs 
         -> [[b]] -- ^ transposed outputs (infinitie list) 
         -> [[b]] -- ^ selected outputs

-- The lazy pattern match is used to avoid evaluating the output list 
-- if length i /= nIns. If that happens the input lists of simulate will 
-- implicitly be simulated, and due to lack of inputs it will cause an error. 
checkIns nIns (i:is) ~(o:os) |  length i == nIns = o : checkIns nIns is os  
checkIns _ _ _ = []