module Lava.Port where



import Control.Applicative
import Control.Monad
import Data.Foldable (Foldable)
import qualified Data.Foldable as Fold
import Data.List as List
import Data.Traversable (Traversable, traverse)
import qualified Data.Traversable as Trav

import Data.Hardware.Internal
import Lava.Model

import qualified "chalmers-lava2000" Lava as L



-- | A rose tree describing the shape of a port. Leaves ('One') hold a single
-- value; internal nodes ('List') hold an ordered sequence of sub-trees.
data PortTree s
    = One {unOne :: s}   -- ^ A leaf holding a single value.
    | List [PortTree s]  -- ^ An ordered sequence of sub-trees.
  deriving (Eq, Show)

-- | Maps over every leaf, preserving the tree shape.
instance Functor PortTree
  where
    fmap f tree = case tree of
        One s   -> One (f s)
        List ps -> List [fmap f p | p <- ps]

-- | Folds over the leaves from left to right, descending depth first.
instance Foldable PortTree
  where
    foldr f z tree = case tree of
        One s   -> f s z
        -- Fold each sub-tree into the accumulator built from its right siblings.
        List ps -> Fold.foldr (\p acc -> Fold.foldr f acc p) z ps

-- | Traverses the leaves left to right, rebuilding the same shape.
instance Traversable PortTree
  where
    traverse f tree = case tree of
        One s   -> One  <$> f s
        List ps -> List <$> traverse (traverse f) ps



-- | Types that can be flattened into a 'PortTree' with leaves of type @s@, and
-- rebuilt from one. The functional dependency @p -> s@ makes the port type
-- determine the leaf type.
--
-- Instances in this module rebuild only from the tree shape that their own
-- 'port' produces. Most of them fail at runtime on any other shape; the
-- 'Maybe' instance returns 'Nothing' instead.
class Port p s | p -> s where
    -- | Flatten a port value into a tree of leaves.
    port   :: p -> PortTree s
    -- | Rebuild a port value from a tree of leaves.
    unport :: PortTree s -> p

-- Leaf instances: a value of each of these types is stored as a single 'One'
-- node. 'unport' goes through the 'unOne' selector, so it fails on a 'List'.

instance Port Signal Signal
  where
    port x   = One x
    unport t = unOne t

instance Port () ()
  where
    port x   = One x
    unport t = unOne t

instance Port Bool Bool
  where
    port x   = One x
    unport t = unOne t

instance Port Int Int
  where
    port x   = One x
    unport t = unOne t

instance Port (L.Signal Bool) (L.Signal Bool)
  where
    port x   = One x
    unport t = unOne t

-- | An optional port: 'Just' becomes a one-element 'List' and 'Nothing' an
-- empty 'List'. Any tree that is not a one-element 'List' unports to
-- 'Nothing'.
instance Port p s => Port (Maybe p) s
  where
    port = maybe (List []) (\p -> List [port p])

    unport tree = case tree of
        List [p] -> Just (unport p)
        _        -> Nothing

-- | A sum port is encoded as a two-element 'List'. The chosen side holds a
-- one-element 'List' containing the payload's tree; the other side is an
-- empty 'List'. Both alternatives therefore share the same outer shape.
instance (Port p1 s, Port p2 s) => Port (Either p1 p2) s
  where
    port (Left  p) = List [List [port p], List []]
    port (Right p) = List [List [], List [port p]]

    unport (List [List [p], List []]) = Left  (unport p)
    unport (List [List [], List [p]]) = Right (unport p)
    -- No other shape can come from 'port' at this type. Fail with a
    -- descriptive message instead of an anonymous pattern-match failure.
    unport _ = error "Lava.Port.unport (Either): malformed PortTree"

-- | A list port is a 'List' with one sub-tree per element, in order.
instance Port p s => Port [p] s
  where
    port             = List . map port
    unport (List ps) = map unport ps
    -- A bare leaf cannot come from 'port' at a list type; report it clearly
    -- instead of falling through a non-exhaustive match.
    unport (One _)   =
        error "Lava.Port.unport ([]): expected a List node, got One"

-- Tuple ports are encoded as a 'List' with one sub-tree per component, in
-- order. 'unport' accepts only a 'List' of exactly that length and otherwise
-- fails with a message naming the expected arity. The catch-all case replaces
-- an anonymous non-exhaustive pattern failure.

instance (Port p1 s, Port p2 s) => Port (p1,p2) s
  where
    port (p1,p2)          = List [port p1, port p2]
    unport (List [p1,p2]) = (unport p1, unport p2)
    unport _              =
        error "Lava.Port.unport (pair): expected a List of 2 sub-trees"

instance (Port p1 s, Port p2 s, Port p3 s) => Port (p1,p2,p3) s
  where
    port (p1,p2,p3)          = List [port p1, port p2, port p3]
    unport (List [p1,p2,p3]) = (unport p1, unport p2, unport p3)
    unport _                 =
        error "Lava.Port.unport (triple): expected a List of 3 sub-trees"

instance (Port p1 s, Port p2 s, Port p3 s, Port p4 s) => Port (p1,p2,p3,p4) s
  where
    port (p1,p2,p3,p4)          = List [port p1, port p2, port p3, port p4]
    unport (List [p1,p2,p3,p4]) = (unport p1, unport p2, unport p3, unport p4)
    unport _                    =
        error "Lava.Port.unport (4-tuple): expected a List of 4 sub-trees"



-- | Relates a port type @p@ to its leaf type @s@ and a leaf-free structure
-- type @t@. The dependencies run both ways: @p@ determines @s@ and @t@, and
-- @s@ together with @t@ determines @p@. The class has no methods; it only
-- constrains types.
class Port p s => PortStruct p s t | p -> s t, s t -> p

-- Leaf ports: the structure of a single value is '()'.
instance PortStruct Signal          Signal          ()
instance PortStruct ()              ()              ()
instance PortStruct Bool            Bool            ()
instance PortStruct Int             Int             ()
instance PortStruct (L.Signal Bool) (L.Signal Bool) ()

-- Container ports: the structure mirrors the container, with each component
-- replaced by its own structure.
instance PortStruct p s t => PortStruct (Maybe p) s (Maybe t)

instance PortStruct p s t => PortStruct [p] s [t]

instance ( PortStruct p1 s t1
         , PortStruct p2 s t2
         )
      => PortStruct (Either p1 p2) s (Either t1 t2)

instance ( PortStruct p1 s t1
         , PortStruct p2 s t2
         )
      => PortStruct (p1,p2) s (t1,t2)

instance ( PortStruct p1 s t1
         , PortStruct p2 s t2
         , PortStruct p3 s t3
         )
      => PortStruct (p1,p2,p3) s (t1,t2,t3)

instance ( PortStruct p1 s t1
         , PortStruct p2 s t2
         , PortStruct p3 s t3
         , PortStruct p4 s t4
         )
      => PortStruct (p1,p2,p3,p4) s (t1,t2,t3,t4)



-- | Apply a function to every leaf of a port and rebuild the result as a port.
-- The shared structure type @t@ forces the input and output ports to have the
-- same shape.
mapPort :: (PortStruct pa sa t, PortStruct pb sb t) => (sa -> sb) -> (pa -> pb)
mapPort f p = unport (fmap f (port p))

-- | Monadic 'mapPort': run an action on every leaf, from left to right, and
-- rebuild the results as a port with the same structure type @t@.
mapPortM
    :: (PortStruct pa sa t, PortStruct pb sb t, Monad m)
    => (sa -> m sb) -> (pa -> m pb)
mapPortM f p = do
    tree <- Trav.mapM f (port p)
    return (unport tree)



-- | Ports whose number of leaves is determined by the type alone, so that they
-- can be rebuilt from a flat list of leaves.
class Port p s => PortFixed p s | p -> s where
    -- | Number of leaves in a port of type @p@. It takes no port argument;
    -- the type index @p@ selects the value.
    lengthFP   :: Res p Int
    -- | Build a port from a flat list of leaves. The tuple instances in this
    -- module split the list from the front by each component's 'lengthFP'.
    fromListFP :: [s] -> p

-- | A single 'Signal' is a fixed port with exactly one leaf.
instance PortFixed Signal Signal
  where
    lengthFP       = R 1
    fromListFP [s] = s
    -- Previously a non-exhaustive match; report the offending length instead.
    fromListFP ss  = error $
        "Lava.Port.fromListFP (Signal): expected exactly 1 signal, got "
        ++ show (length ss)

-- | A pair is fixed when both components are. Its leaves are the first
-- component's leaves followed by the second's.
instance (PortFixed p1 s, PortFixed p2 s) => PortFixed (p1,p2) s
  where
    lengthFP = R (n1 + n2)
      where
        n1 = result (lengthFP :: Res p1 Int)
        n2 = result (lengthFP :: Res p2 Int)

    -- The first component takes the leading leaves; the rest go to the second.
    fromListFP ss = (fromListFP front, fromListFP back)
      where
        (front, back) = splitAt (result (lengthFP :: Res p1 Int)) ss

-- | A triple is fixed when all components are. Its leaves are laid out
-- component by component, in order.
instance (PortFixed p1 s, PortFixed p2 s, PortFixed p3 s)
      => PortFixed (p1,p2,p3) s
  where
    lengthFP = R (n1 + n2 + n3)
      where
        n1 = result (lengthFP :: Res p1 Int)
        n2 = result (lengthFP :: Res p2 Int)
        n3 = result (lengthFP :: Res p3 Int)

    -- Peel off the first two components' leaves; the remainder is the third.
    fromListFP ss = (fromListFP xs1, fromListFP xs2, fromListFP xs3)
      where
        (xs1, rest) = splitAt (result (lengthFP :: Res p1 Int)) ss
        (xs2, xs3)  = splitAt (result (lengthFP :: Res p2 Int)) rest

-- | A 4-tuple is fixed when all components are. Its leaves are laid out
-- component by component, in order.
instance (PortFixed p1 s, PortFixed p2 s, PortFixed p3 s, PortFixed p4 s)
      => PortFixed (p1,p2,p3,p4) s
  where
    lengthFP = R (n1 + n2 + n3 + n4)
      where
        n1 = result (lengthFP :: Res p1 Int)
        n2 = result (lengthFP :: Res p2 Int)
        n3 = result (lengthFP :: Res p3 Int)
        n4 = result (lengthFP :: Res p4 Int)

    -- Peel off the first three components' leaves; the remainder is the fourth.
    fromListFP ss =
        (fromListFP xs1, fromListFP xs2, fromListFP xs3, fromListFP xs4)
      where
        (xs1, rest1) = splitAt (result (lengthFP :: Res p1 Int)) ss
        (xs2, rest2) = splitAt (result (lengthFP :: Res p2 Int)) rest1
        (xs3, xs4)   = splitAt (result (lengthFP :: Res p3 Int)) rest2



-- | Converts between a port tree of Lava 2000 signals and Lava 2000's
-- 'L.Struct' representation. 'One' corresponds to 'L.Object' (the signal's
-- symbol) and 'List' to 'L.Compound'.
instance L.Generic (PortTree (L.Signal Bool))
  where
    struct tree = case tree of
        One (L.Signal sym) -> L.Object sym
        List ts            -> L.Compound [L.struct t | t <- ts]

    construct st = case st of
        L.Object sym  -> One (L.Signal sym)
        L.Compound ts -> List [L.construct t | t <- ts]

  -- Lava 2000's 'L.Generic' class (instantiated above for 'PortTree') plays roughly the same role as this module's 'Port' class.