module Circus.Simplify (simplify) where

import           Circus.Types
import           Control.Arrow
import           Data.List
import qualified Data.Map as M
import           Data.Maybe
import           Data.Set (Set)
import qualified Data.Set as S

------------------------------------------------------------------------------
-- | Split the ports of a cell by direction.
--
-- Returns @(inputs, others)@: the first set holds the names of ports whose
-- direction is 'Input'; the second holds the names of every remaining port.
cellPorts :: Cell -> (Set PortName, Set PortName)
cellPorts c = (M.keysSet inputs, M.keysSet others)
  where
    (inputs, others) = M.partition (== Input) (cellPortDirections c)


------------------------------------------------------------------------------
-- | Gather the bits read and the bits driven within the given module.
--
-- Returns @(consumed, produced)@:
--
--   * @consumed@ holds every bit connected to an 'Input' port of some cell,
--     plus the bits of the module's own non-'Input' ports (read from outside).
--   * @produced@ holds every bit connected to a non-'Input' port of some cell,
--     plus the bits of the module's own 'Input' ports (driven from outside).
--
-- The module-port bits are computed once, independently of the cells, so they
-- are present in the result even when the module has no cells.
ioBits :: Module -> (Set Bit, Set Bit)
ioBits m =
  (S.fromList oports, S.fromList iports) <> foldMap cellBits (moduleCells m)
  where
    -- Module ports do not depend on any cell, so split them only once rather
    -- than once per cell.
    (iports, oports) =
      (concatMap snd *** concatMap snd)
        . partition ((== Input) . fst)
        . fmap (portDirection &&& portBits)
        . M.elems
        $ modulePorts m

    -- Bits attached to the given ports of a cell; a port with no entry in
    -- 'cellConnections' contributes no bits.
    connectedBits c = foldMap (fromMaybe [] . flip M.lookup (cellConnections c))

    -- (bits on the cell's Input ports, bits on its other ports)
    cellBits c =
      let (ip, op) = cellPorts c
       in (S.fromList (connectedBits c ip), S.fromList (connectedBits c op))


------------------------------------------------------------------------------
-- | Delete every cell that has at least one non-'Input' port and whose
-- non-'Input' ports are connected only to bits in the @to_kill@ set.
--
-- Cells with no non-'Input' ports are always kept. A non-'Input' port with no
-- entry in 'cellConnections' contributes no bits, so on its own it does not
-- prevent deletion.
pruneCellsOutput :: Set Bit -> Module -> Module
pruneCellsOutput to_kill m =
  m { moduleCells = M.filter (not . should_kill) $ moduleCells m }
  where
    should_kill :: Cell -> Bool
    should_kill c
      | S.null op = False
      -- Consider every output port, not just single-output cells; otherwise
      -- dead multi-output cells (and whatever feeds them) are never removed.
      | otherwise = all (`S.member` to_kill) (foldMap bitsOf op)
      where
        (_, op) = cellPorts c
        bitsOf pn = fromMaybe [] $ M.lookup pn $ cellConnections c


------------------------------------------------------------------------------
-- | Repeatedly remove cells whose outputs are never read, stopping once a
-- pass leaves the module unchanged.
--
-- Each pass treats as removable the bits that are driven (by a cell output or
-- a module input) but never read (by a cell input or a module output).
-- Removing a cell can make the bits feeding it unread in turn, hence the
-- iteration to a fixed point.
simplify :: Module -> Module
simplify m
  | m == pruned = m
  | otherwise = simplify pruned
  where
    (consumed, produced) = ioBits m
    pruned = pruneCellsOutput (produced S.\\ consumed) m