{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Safe #-}
module Copilot.Theorem.TransSys.Transform
( mergeNodes
, inline
, removeCycles
, complete
) where
import Copilot.Theorem.TransSys.Spec
import Copilot.Theorem.TransSys.Renaming
import Copilot.Theorem.Misc.Utils
import Control.Monad (foldM, forM_, forM, guard)
import Data.List (sort, (\\), intercalate, partition, foldl')
import Control.Exception.Base (assert)
import Data.Map (Map, (!))
import Data.Set (member)
import Data.Bimap (Bimap)
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.Graph as Graph
import qualified Data.Bimap as Bimap
-- | Qualify a variable name with a leading string, joining the two parts
-- with a dot: @prefix "n" (Var "x")@ yields @Var "n.x"@.
prefix :: String -> Var -> Var
prefix qualifier (Var name) = Var (concat [qualifier, ".", name])
-- | Separator inserted between the sorted node identifiers when
-- 'mergeNodes' builds the identifier of a merged node. 'complete' also
-- looks for its first character to recognise such composite identifiers.
ncNodeIdSep :: String
ncNodeIdSep = "-"
-- | Merge the nodes whose identifiers appear in @toMergeIds@ into a single
-- node.
--
-- The merged node is placed at the head of the node list, followed by the
-- remaining nodes updated through 'updateOtherNode'. Its identifier is the
-- top node identifier when the top node is among the merged ones, and
-- otherwise the sorted merged identifiers joined with 'ncNodeIdSep'.
-- Properties that referred to a merged node are redirected to the new node.
mergeNodes :: [NodeId] -> TransSys -> TransSys
mergeNodes toMergeIds spec =
  spec
    { specNodes = mergedNode : map relink untouched
    , specProps = Map.map renamingExtF (specProps spec)
    }
  where
    -- Split the nodes into those being merged and those left alone.
    (toMerge, untouched) =
      partition (\n -> nodeId n `elem` toMergeIds) (specNodes spec)

    relink = updateOtherNode newNodeId toMergeIds renamingExtF

    newNodeId
      | topId `elem` toMergeIds = topId
      | otherwise               = intercalate ncNodeIdSep (sort toMergeIds)
      where
        topId = specTopNodeId spec

    mergedNode = Node
      { nodeId           = newNodeId
      , nodeDependencies = dependencies
      , nodeImportedVars = importedVars
      , nodeLocalVars    = mergeVarsDescrs toMerge renamingF
      , nodeConstrs      = mergeConstrs toMerge renamingF
      }

    -- Dependencies of the merged nodes that point outside the merged set.
    dependencies = nub' $ do
      n   <- toMerge
      dep <- nodeDependencies n
      guard (dep `notElem` toMergeIds)
      return dep

    -- The three renaming passes run in this order; the resulting renaming
    -- function is shared by every other binding of this 'where' block.
    (importedVars, renamingF) = runRenaming $ do
      renameLocalVars toMerge
      redirectLocalImports toMerge
      selectImportedVars toMerge untouched dependencies

    -- Rewrite an external reference: variables of merged nodes now live in
    -- the new node under their renamed identifier.
    renamingExtF gv@(ExtVar nId _)
      | nId `elem` toMergeIds = ExtVar newNodeId (renamingF gv)
      | otherwise             = gv
-- | Adjust a node that is not part of a merge.
--
-- If the node depended on at least one of @mergedNodesIds@, those
-- dependencies are dropped and @newNodeId@ is put at the head of the
-- dependency list; otherwise the list is left untouched. Every imported
-- external variable is passed through @renamingF@.
updateOtherNode :: NodeId -> [NodeId] -> (ExtVar -> ExtVar) -> Node -> Node
updateOtherNode newNodeId mergedNodesIds renamingF n =
  n { nodeDependencies = deps'
    , nodeImportedVars = imports'
    }
  where
    deps      = nodeDependencies n
    remaining = deps \\ mergedNodesIds

    -- A shorter list means at least one merged node was a dependency.
    deps'
      | length remaining < length deps = newNodeId : remaining
      | otherwise                      = deps

    imports' =
      Bimap.fromList
        . map (\(lv, gv) -> (lv, renamingF gv))
        . Bimap.toList
        $ nodeImportedVars n
-- | Rename the variables of an expression belonging to node @nId@.
--
-- 'transformExpr' is given a rewriting step that replaces each variable
-- @v@ by @renamingF (ExtVar nId v)@ and leaves every other expression
-- form as it is.
updateExpr :: NodeId -> (ExtVar -> Var) -> Expr t -> Expr t
updateExpr nId renamingF = transformExpr renameVar
  where
    renameVar :: forall a. Expr a -> Expr a
    renameVar expr = case expr of
      VarE ty v -> VarE ty (renamingF (ExtVar nId v))
      _         -> expr
-- | Collect the local variables of all merged nodes into a single map.
--
-- Each variable is re-keyed by @renamingF@ applied to its fully qualified
-- name, and the variables mentioned in its definition are renamed the same
-- way, relative to the node the definition came from.
mergeVarsDescrs :: [Node] -> (ExtVar -> Var) -> Map Var VarDescr
mergeVarsDescrs toMerge renamingF =
  Map.fromList
    [ (renamingF (ExtVar nId v), VarDescr t (renameDef nId def))
    | n <- toMerge
    , let nId = nodeId n
    , (v, VarDescr t def) <- Map.toList (nodeLocalVars n)
    ]
  where
    -- Rename every variable occurring in a definition owned by node @nId@.
    renameDef :: NodeId -> VarDef t -> VarDef t
    renameDef nId def = case def of
      Pre val v' -> Pre val (renamingF (ExtVar nId v'))
      Expr e     -> Expr (updateExpr nId renamingF e)
      Constrs cs -> Constrs (map (updateExpr nId renamingF) cs)
-- | Concatenate the constraints of all merged nodes, in node order, after
-- renaming the variables of each constraint relative to its owning node.
mergeConstrs :: [Node] -> (ExtVar -> Var) -> [Expr Bool]
mergeConstrs toMerge renamingF = concatMap renamedConstrs toMerge
  where
    renamedConstrs n = map (updateExpr (nodeId n) renamingF) (nodeConstrs n)
-- | Register a renaming for every local variable of every merged node.
--
-- Nodes are visited in list order and variables in key order. The fresh
-- name is requested with the node-prefixed variable name as its only
-- suggestion.
renameLocalVars :: [Node] -> Renaming ()
renameLocalVars toMerge =
  forM_ toMerge $ \n ->
    forM_ (Map.keys (nodeLocalVars n)) $ \v -> do
      let nId = nodeId n
      fresh <- getFreshName [nId `prefix` v]
      rename nId v fresh
-- | Compute the imported variables of the merged node.
--
-- Every local variable of every node listed in @dependencies@ is
-- considered in turn. A fresh alias is requested for it, and each merged
-- node that already imports the variable has its own alias renamed to the
-- fresh one. The variable is recorded in the resulting bimap only when at
-- least one merged node imported it.
--
-- Every identifier in @dependencies@ must belong to one of @otherNodes@;
-- the map lookup fails otherwise.
selectImportedVars :: [Node] -> [Node] -> [NodeId]
                      -> Renaming (Bimap Var ExtVar)
selectImportedVars toMerge otherNodes dependencies =
  foldM step Bimap.empty candidates
  where
    nodeById = Map.fromList (map (\n -> (nodeId n, n)) otherNodes)

    -- All (node, variable) pairs that could be imported.
    candidates = do
      nId <- dependencies
      v   <- Map.keys (nodeLocalVars (nodeById ! nId))
      return (nId, v)

    step acc (nId, v) = do
      let target = ExtVar nId v
      -- The fresh name is requested even when the import ends up unused.
      alias <- getFreshName [nId `prefix` v]
      hits <- forM toMerge $ \n ->
        case Bimap.lookupR target (nodeImportedVars n) of
          Just lv -> do
            rename (nodeId n) lv alias
            return True
          Nothing -> return False
      return $ if or hits then Bimap.insert alias target acc else acc
-- | Redirect imports that stay inside the merged set.
--
-- When a merged node imports a variable owned by another merged node, the
-- alias is renamed to whatever the renaming function maps the imported
-- variable to. That function is obtained from 'getRenamingF' once, before
-- any of these renames is issued.
redirectLocalImports :: [Node] -> Renaming ()
redirectLocalImports toMerge = do
  renamingF <- getRenamingF
  forM_ localImports $ \(nId, alias, ev) ->
    rename nId alias (renamingF ev)
  where
    mergedIds = Set.fromList (map nodeId toMerge)

    -- Imports of merged nodes whose target is itself a merged node.
    localImports =
      [ (nodeId n, alias, ev)
      | n <- toMerge
      , (alias, ev@(ExtVar origin _)) <- Bimap.toList (nodeImportedVars n)
      , origin `member` mergedIds
      ]
-- | Collapse the whole system into one node by merging all of its nodes
-- with 'mergeNodes'.
inline :: TransSys -> TransSys
inline spec = mergeNodes (map nodeId (specNodes spec)) spec
-- | Remove dependency cycles between nodes.
--
-- The strongly connected components of the node dependency graph are
-- computed, every cyclic component is merged into a single node with
-- 'mergeNodes', and the nodes are then re-ordered following the component
-- order returned by 'Graph.stronglyConnComp'.
--
-- Calls 'error', naming the offending nodes, if a cycle is still present
-- after merging.
removeCycles :: TransSys -> TransSys
removeCycles spec =
  topoSort $ foldr mergeComp spec (buildScc nodeId $ specNodes spec)
  where
    mergeComp (Graph.AcyclicSCC _)  s = s
    mergeComp (Graph.CyclicSCC ids) s = mergeNodes ids s

    -- @nrep@ selects what each graph vertex carries: the node identifier
    -- while looking for cycles, the node itself while sorting.
    buildScc nrep ns =
      let depGraph = map (\n -> (nrep n, nodeId n, nodeDependencies n)) ns
      in Graph.stronglyConnComp depGraph

    topoSort s =
      s { specNodes = map acyclicNode $ buildScc id (specNodes s) }

    -- Total replacement for a partial lambda: a leftover cycle is reported
    -- explicitly instead of failing with an anonymous pattern-match error.
    acyclicNode (Graph.AcyclicSCC n) = n
    acyclicNode scc =
      error $
        "removeCycles: nodes still form a dependency cycle after merging: "
          ++ intercalate ", " (map nodeId (Graph.flattenSCC scc))
-- | Complete the dependencies and imported variables of every node.
--
-- The input is asserted to be topologically sorted. The top node is first
-- made to depend on every other node. Nodes are then processed in list
-- order: each one inherits the dependencies of its direct dependencies,
-- and imports every local variable of every resulting dependency.
--
-- Dependencies are looked up among the nodes already processed, so a
-- dependency that appears later in the list makes the map lookup fail.
complete :: TransSys -> TransSys
complete spec =
  assert (isTopologicallySorted spec) $ spec { specNodes = specNodes' }
  where
    -- 'completeNode' conses onto its accumulator, hence the final reverse.
    specNodes' =
      reverse
        . foldl' completeNode []
        . specNodes
        . completeTopNodeDeps
        $ spec

    -- Make the top node depend on all the other nodes.
    completeTopNodeDeps s = s { specNodes = map addAllDeps nodes }
      where
        nodes = specNodes s
        addAllDeps n
          | nodeId n == specTopNodeId s =
              n { nodeDependencies = map nodeId nodes \\ [nodeId n] }
          | otherwise = n

    -- @ns@ holds the nodes completed so far, most recent first.
    completeNode :: [Node] -> Node -> [Node]
    completeNode ns n =
      n { nodeDependencies = dependencies'
        , nodeImportedVars = importedVars'
        } : ns
      where
        nsMap = Map.fromList [ (nodeId done, done) | done <- ns ]

        directDeps = nodeDependencies n

        -- Dependencies already completed carry their own inherited
        -- dependencies, so a single extra level is enough here.
        dependencies' =
          nub' $ directDeps ++ concatMap (nodeDependencies . (nsMap !)) directDeps

        toImportVars =
          nub' [ ExtVar nId v
               | nId <- dependencies'
               , v <- Map.keys (nodeLocalVars (nsMap ! nId))
               ]

        -- Total check for a composite identifier built by 'mergeNodes':
        -- does it contain the first character of the separator?
        isMergedId nId = case ncNodeIdSep of
          (c : _) -> c `elem` nId
          []      -> False

        tryImport acc ev@(ExtVar n' v) = do
          -- Variables coming from merged nodes are already decorated, so
          -- the bare name is tried first to keep aliases readable.
          let preferedName
                | isMergedId n' = v
                | otherwise     = n' `prefix` v
          alias <- getFreshName [preferedName, n' `prefix` v]
          return $ Bimap.tryInsert alias ev acc

        importedVars' = fst . runRenaming $ do
          -- Reserve the names in @nodeVarsSet n@ (presumably the variables
          -- the node already uses) so that fresh aliases avoid them.
          forM_ (Set.toList $ nodeVarsSet n) addReservedName
          foldM tryImport (nodeImportedVars n) toImportVars