module Lava.Interpret where



import Control.Arrow ((***))
import Control.Monad.State
import qualified Data.Foldable as Fold
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Traversable as Trav

import Data.Hardware.Internal
import Data.Logical.Knot
import Lava.Model
import Lava.Port



-- | Query the value of a signal inside the knot computation.
--
-- The interpretation's 'defaultVal' is handed to 'askKnotDef' as its default
-- argument (presumably the value used when nothing is known about the signal;
-- see "Data.Logical.Knot" to confirm).
askSig :: Interpretation lib x -> Signal -> Knot Signal x x
askSig = askKnotDef . defaultVal

-- | Assign values to a list of signals.
--
-- Signals and values are paired positionally. A 'Nothing' means that no
-- assignment is issued for the corresponding signal, and surplus elements of
-- the longer list are ignored (a consequence of 'zip'). The interpretation
-- argument is not used.
tellSigs :: Interpretation lib x -> [Signal] -> [Maybe x] -> Knot Signal x ()
tellSigs _interp sigs vals = mapM_ (uncurry (*=)) defined
  where
    -- Only the positions that actually carry a value.
    defined = [(sig, val) | (sig, Just val) <- zip sigs vals]



-- | Interpret a list of cells, producing a value for each signal.
--
-- @es@ is a list of explicit signal interpretations. The signals mentioned in
-- this list must be valid according to prop_validSignals.
--
-- The knot computation first assigns every explicit interpretation and then
-- visits each cell once. The result is the second component of what 'accKnot'
-- returns when run with the interpretation's 'accumulator'.
interpretCells :: forall lib x
     . CellLibrary lib
    => Interpretation lib x
    -> [(Signal, x)]
    -> [(CellId, (lib,[Signal]))]
    -> Map Signal x
interpretCells interp es cells = snd . accKnot (accumulator interp) $ do
    -- Constrain explicitly interpreted signals.
    mapM_ (\(sig, val) -> sig *= val) es
    -- Propagate values across each cell.
    mapM_ propagateCell cells
  where
    -- Ask for the current values of the cell's signals (outputs first, then
    -- inputs), run the propagator on them, and tell back whatever it yields.
    propagateCell (cid, (ct, ins)) = do
        let sigs = cellOutputs cid ct ++ ins
        vals <- mapM (askSig interp) sigs
        tellSigs interp sigs (propagator interp ct vals)



-- | Interpret a design given as a port tree of signals plus its database.
--
-- The cells of the database are interpreted with 'interpretCells', using @es@
-- as explicit signal interpretations. Every signal in the port tree is then
-- replaced by its value in the resulting map. The lookup uses 'Map.!', so a
-- port signal that is missing from the map is a runtime error.
--
-- Returns the interpreted port tree together with the original database
-- paired with the signal map.
interpret__
    :: CellLibrary lib
    => Interpretation lib x
    -> [(Signal, x)]
    -> (PortTree Signal, DesignDB lib)
    -> (PortTree x, InterpDesignDB lib x)
interpret__ interp es (ps, db) = (fmap valueOf ps, (db, sigMap))
  where
    cells       = Map.toList (cellDB db)
    sigMap      = interpretCells interp es cells
    valueOf sig = sigMap Map.! sig



-- | Like 'interpret__', but starting from a 'Lava' computation.
--
-- The computation is run with 'runLava' and its result is passed on to
-- 'interpret__' together with the explicit signal interpretations @es@.
interpret_
    :: CellLibrary lib
    => Interpretation lib x
    -> [(Signal, x)]
    -> Lava lib (PortTree Signal)
    -> (PortTree x, InterpDesignDB lib x)
interpret_ interp es = interpret__ interp es . runLava



-- | Interpret a 'Lava' computation whose result is an arbitrary port
-- structure.
--
-- The result of the computation is converted to a port tree with 'port',
-- interpreted by 'interpret_' with no explicit signal interpretations, and
-- the interpreted tree is converted back with 'unport'. The interpreted
-- database is returned unchanged alongside it.
interpret
    :: ( CellLibrary lib
       , PortStruct ps Signal t
       , PortStruct px x      t
       )
    => Interpretation lib x -> Lava lib ps -> (px, InterpDesignDB lib x)
interpret interp lava =
    (unport *** id) (interpret_ interp [] (liftM port lava))



-- | Replace every element of a port tree by a fresh primary input signal.
--
-- The elements themselves are ignored; only the shape of the tree matters.
-- Signals are numbered -1, -2, ... in traversal order. Negative indices are
-- used to avoid clashes with user-defined primary inputs.
inputToSig :: PortTree x -> PortTree Signal
inputToSig tree = evalState (Trav.mapM freshInput tree) (-1)
  where
    -- Hand out the current index and step the counter downwards.
    freshInput _ =
        get >>= \iid -> put (pred iid) >> return (PrimInpSig iid)



-- | Interpret a circuit function operating on port trees.
--
-- The input values are replaced by fresh primary input signals (see
-- 'inputToSig'), the circuit function is applied to those signals, and the
-- result is interpreted by 'interpret_' with each fresh signal explicitly
-- bound to the corresponding input value.
--
-- Note that the fresh input signals will not be present in the design
-- database seen during interpretation, so technically the database may not be
-- valid. It would be possible to pass them separately and add them to the
-- database, but there is no point in doing that, since the interpretation
-- only cares about the cells in the database.
interpretFuncP
    :: CellLibrary lib
    => Interpretation lib x
    -> (PortTree Signal -> Lava lib (PortTree Signal))
    -> (PortTree x -> (PortTree x, InterpDesignDB lib x))
interpretFuncP interp fs pxi =
    let sigTree  = inputToSig pxi
        -- Pair each fresh input signal with the value it stands for.
        explicit = zip (Fold.toList sigTree) (Fold.toList pxi)
    in  interpret_ interp explicit (fs sigTree)



-- | Interpret a circuit function operating on arbitrary port structures.
--
-- The input is converted to a port tree with 'port', and the circuit function
-- is wrapped so that it consumes and produces port trees. 'interpretFuncP'
-- does the actual interpretation, and the interpreted output tree is
-- converted back with 'unport'. The interpreted database is returned
-- unchanged alongside it.
interpretFunc
    :: ( CellLibrary lib
       , PortStruct pxi x      ti
       , PortStruct psi Signal ti
       , PortStruct pso Signal to
       , PortStruct pxo x      to
       )
    => Interpretation lib x
    -> (psi -> Lava lib pso)
    -> (pxi -> (pxo, InterpDesignDB lib x))
interpretFunc interp f pxi =
    (unport *** id)
        (interpretFuncP interp (liftM port . f . unport) (port pxi))