module Lava.Misc where



import Control.Monad
import qualified Data.Foldable as Fold
import qualified Data.Map as Map
import Test.QuickCheck (forAll, choose, quickCheck)

import Data.Hardware.Internal
import Lava.Model
import Lava.Loop
import Lava.Port
import Lava.Interpret

import qualified "chalmers-lava2000" Lava as L
import qualified "chalmers-lava2000" Lava.Ref as L



-- | Declare a primary input port of type @p@.
--
-- The port is assembled with 'fromListFP' from @width@ signals, each made by
-- 'inputSig', where @width@ is the fixed length of the port type @p@.
input :: forall lib m p . (MonadLava lib m, PortFixed p Signal) => m p
input = do
    sigs <- replicateM width inputSig
    return (fromListFP sigs)
  where
    -- The width is read off the type @p@ (relies on the scoped @forall@).
    width = result (lengthFP :: Res p Int)

-- | Declare @n@ primary input ports, each one created by 'input'.
inputList :: (MonadLava lib m, PortFixed p Signal) => Int -> m [p]
inputList n = sequence (replicate n input)



-- | Declare a cell of kind @cid@.
--
-- The input port is flattened into a list of signals and handed to
-- 'cellList'; the list it returns is reassembled into the fixed-size output
-- port with 'fromListFP'.
cell
    :: forall m lib pi po
     . (MonadLava lib m, PortFixed pi Signal, PortFixed po Signal)
    => lib -> pi -> m po

cell cid inp = do
    outs <- cellList cid (Fold.toList (port inp))
    return (fromListFP outs)



-- | Declare a cell of kind @cid@ that has no inputs. The signals returned by
-- 'cellList' are assembled into the output port with 'fromListFP'.
sourceCell :: (MonadLava lib m, PortFixed p Signal) => lib -> m p
sourceCell cid = cellList cid [] >>= return . fromListFP

-- | Declare a cell of kind @cid@ that has no outputs. The signals of port @p@
-- are its inputs, and whatever 'cellList' returns is discarded.
sinkCell :: (MonadLava lib m, PortFixed p Signal) => lib -> p -> m ()
sinkCell cid p = cellList cid sigs >> return ()
  where
    sigs = Fold.toList (port p)

-- | Declare a cell of kind @cid@ with neither inputs nor outputs (purely
-- physical).
--
-- The second argument is returned unchanged, which lets the cell be dropped
-- into a monadic pipeline.
physCell :: MonadLava lib m => lib -> (a -> m a)
physCell cid a = do
    _ <- cellList cid []
    return a



-- | Declare the label @tag@ on a port by applying @'labelSig' tag@ through
-- 'mapPortM'.
--
-- Only the side effect of declaring the label matters; the resulting port
-- is returned so the call can be chained.
label :: (MonadLava lib m, PortStruct p Signal t) => Tag -> p -> m p
label tag = mapPortM (labelSig tag)



-- | Turn a circuit into a plain Lava2000 function.
--
-- The circuit is run through 'interpretFunc' with 'lava2000Interp'. Only the
-- first component of the interpretation result (the outputs) is kept.
toLava2000
    :: ( MonadLava lib m
       , PortStruct pli (L.Signal Bool) ti
       , PortStruct psi Signal          ti
       , PortStruct pso Signal          to
       , PortStruct plo (L.Signal Bool) to
       )
    => (psi -> m pso) -> (pli -> plo)

toLava2000 circ inputs =
    fst (interpretFunc lava2000Interp (toLava . circ) inputs)



-- | Simulate a circuit with Lava2000's 'L.simulateSeq'.
--
-- The circuit is translated to Lava2000 through 'lava2000Interp'. Every
-- 'Int' in an input port must be 0 or 1, and any other value is an 'error'.
-- The output signals are converted back to 0 / 1. The list elements are
-- presumably successive clock cycles, following Lava2000's 'L.simulateSeq'
-- (confirm against the Lava2000 documentation).
simulateSeq
    :: ( MonadLava lib m
       , PortStruct pni Int    ti
       , PortStruct psi Signal ti
       , PortStruct pso Signal to
       , PortStruct pno Int    to
       )
    => (psi -> m pso) -> ([pni] -> [pno])

simulateSeq circ
    = map (unport . fmap sigToInt)
    . L.simulateSeq circP
    . map (fmap intToSig . port)
  where
    -- The circuit as a Lava2000 function over generic ports.
    circP = fst . interpretFuncP lava2000Interp
        (liftM port . toLava . circ . unport)

    intToSig 0 = L.low
    intToSig 1 = L.high
    intToSig _ = error "Only values 0 and 1 allowed"

    -- Simulation outputs are expected to be Boolean constants. Any other
    -- node is reported with an explicit message rather than a bare
    -- non-exhaustive-pattern failure.
    sigToInt (L.Signal (L.Symbol r)) = case L.deref r of
        L.Bool False -> 0
        L.Bool True  -> 1
        _            ->
            error "simulateSeq: non-Boolean signal in simulation output"



-- | Simulate a circuit on a single input.
--
-- 'simulateSeq' is run on a one-element input list and its first output is
-- returned. The empty case is matched explicitly, so the function is total
-- and fails with a descriptive message instead of using the partial 'head'.
simulate
    :: ( MonadLava lib m
       , PortStruct pni Int    ti
       , PortStruct psi Signal ti
       , PortStruct pso Signal to
       , PortStruct pno Int    to
       )
    => (psi -> m pso) -> (pni -> pno)

simulate circ inp = case simulateSeq circ [inp] of
    out : _ -> out
    []      -> error "simulate: simulation produced no output"



-- | @encodeBin n x@
--
-- Encodes the number @x@ as a binary number of exactly @n@ digits, least
-- significant digit first. The resulting list contains only zeroes and ones.
--
-- Digits above position @n@ are dropped, so the result represents @x@ modulo
-- @2^n@ ('decodeBin' recovers exactly that value). A negative @x@ is
-- therefore encoded in @n@-bit two's complement instead of yielding
-- non-binary digits. A non-positive @n@ gives the empty list.
encodeBin :: Int -> Int -> [Int]
encodeBin n x = take n (map (`mod` 2) (iterate (`div` 2) x))
  -- Repeated halving produces the digits in order. A non-negative x settles
  -- at 0, which gives the zero padding. A negative x settles at -1, which
  -- gives the run of sign (one) bits.



-- | Interpret a list of binary digits, least significant digit first, as a
-- number. Calls 'error' if any element is neither 0 nor 1.
decodeBin :: [Int] -> Int
decodeBin = foldr step 0
  where
    -- Horner's scheme: each more significant digit doubles the rest.
    step d acc = digit d + 2 * acc

    digit d
      | d == 0 || d == 1 = d
      | otherwise        = error "decodeBin: Not a binary number"

-- | Property: 'encodeBin' yields exactly @n@ digits, for widths 0..10 and
-- values 0..2^11.
prop_encodeBin =
    forAll (choose (0,10)) withWidth
  where
    withWidth n  = forAll (choose (0,2^11)) (hasWidth n)
    hasWidth n i = length (encodeBin n i) == n

-- | Property: decoding the @n@-digit encoding of @i@ gives back @i@ modulo
-- @2^n@, for widths 0..10 and values 0..2^11.
prop_encodeDecodeBin =
    forAll widths $ \n ->
    forAll values $ \i ->
    decodeBin (encodeBin n i) == i `mod` 2^n
  where
    widths = choose (0,10)
    values = choose (0,2^11)

-- | Run the 'encodeBin' / 'decodeBin' properties through QuickCheck.
checkAll :: IO ()
checkAll = mapM_ quickCheck [prop_encodeBin, prop_encodeDecodeBin]



-- | Check a single-output circuit with Lava2000's SMV back end.
--
-- The circuit is converted with 'toLava2000' and quantified over input
-- lists of the port's fixed width ('L.forAll' over 'L.list'), and the result
-- is passed to 'L.smv'. The value 'L.smv' returns is discarded.
verify
    :: forall lib m ps
     . (MonadLava lib m, PortFixed ps Signal)
    => (ps -> m Signal) -> IO ()

verify circ = do
    _ <- L.smv (L.forAll (L.list width) property)
    return ()
  where
    -- Port width, taken from the type @ps@.
    width    = result (lengthFP :: Res ps Int)
    property = toLava2000 (circ . fromListFP)



-- | Interpret the circuit with 'depthInterp' and return the full result: the
-- interpreted outputs together with the 'InterpDesignDB' of 'Int' values.
--
-- Calls 'error' if 'hasLoopDB' reports a combinational feedback loop in the
-- design database.
depth :: (MonadLava lib m, PortStruct ps Signal t, PortStruct pd Int t) =>
    m ps -> (pd, InterpDesignDB lib Int)
depth circ =
    case interpret depthInterp (toLava circ) of
      analysed@(_, (db, _))
        | hasLoopDB True db -> error "depth: Combinational feedback loop"
        | otherwise         -> analysed

-- | Run the circuit and pair its design database with the result of
-- 'fanoutDB' on it, where every entry is replaced by its 'length'.
fanout :: MonadLava lib m => m a -> InterpDesignDB lib Int
fanout circ =
    let db = snd (runLava (toLava circ))
    in  (db, fmap length (fanoutDB db))

-- | Number of entries in the cell map ('cellDB') of the design that the
-- circuit builds.
size :: MonadLava lib m => m p -> Int
size circ = Map.size (cellDB (snd (runLava (toLava circ))))