module Feldspar.Compiler.Transformation.Lifting where
import Feldspar.Core.Graph
import Feldspar.Core.Types hiding (typeOf)
import Feldspar.Compiler.Transformation.GraphUtils
import Data.List
-- | Run the NoInline lifting pass over the hierarchy of a graph.
-- The graph interface is carried over unchanged.
replaceNoInlines :: HierarchicalGraph -> HierarchicalGraph
replaceNoInlines graph = HierGraph lifted (hierGraphInterface graph)
  where
    lifted = replaceNoInlinesHr $ graphHierarchy graph
-- | 'replaceNoInlinesHr' applied to every hierarchy of a list.
replaceNoInlinesHrList :: [Hierarchy] -> [Hierarchy]
replaceNoInlinesHrList = map replaceNoInlinesHr
-- | Process every (node, sub-hierarchies) entry of a hierarchy with
-- 'replaceNoInlinesNode'.
replaceNoInlinesHr :: Hierarchy -> Hierarchy
replaceNoInlinesHr (Hierarchy entries) =
  Hierarchy [replaceNoInlinesNode entry | entry <- entries]
-- | Lift the free variables of a 'NoInline' node into explicit inputs.
--
-- For a 'NoInline' node, the pass collects every variable that meets two
-- conditions: the body reads it (node inputs, interfaces of nested
-- functions, or the node's own interface output), and it names a node that
-- is not part of the body's hierarchies.
--
-- When such variables exist, 'changeInp' does two things: it appends them
-- to the node's input tuple, and it appends their types to the relevant
-- type annotations.
--
-- The body is then rewritten with 'replaceVars'. The lifted variables map
-- (via 'varChange') to input components @1..n@. The original input variable
-- maps via 'inpVarsChange' (its path is prefixed with @0@).
--
-- Sub-hierarchies are processed recursively in every case. Nodes that are
-- not 'NoInline' are only recursed into.
replaceNoInlinesNode :: (Node, [Hierarchy]) -> (Node, [Hierarchy])
replaceNoInlinesNode (n, hs) =
  case function n of
    NoInline _ interface
      | null replaceList -> (n, replaceNoInlinesHrList hs)
      | otherwise        -> (nNew, replaceNoInlinesHrList hsNew)
      where
        inpId = interfaceInput interface
        -- Changes found in the interface output come first, then those found
        -- while walking the body. foldl' keeps the accumulator evaluated
        -- instead of building a chain of thunks.
        replaceList =
          foldl' (collectChangesHr (inpId, hs))
                 (collectChangesInterface interface hs)
                 hs
        (nNew_, hsNew_) = changeInp inpId replaceList (n, hs)
        -- The original input becomes component 0 of the extended input.
        fullReplaceList = ((inpId, []), inpVarsChange) : map fst replaceList
        nNew  = nNew_ { function = replaceVars fullReplaceList (function nNew_) }
        hsNew = map (replaceVars fullReplaceList) hsNew_
    _ -> (n, replaceNoInlinesHrList hs)
-- | Extend a 'NoInline' node so that it receives the lifted variables.
--
-- @inpNode@ is the id of the body's input node and @chLs@ is the collected
-- list of changes. The result:
--
-- * appends the lifted variables to the node's actual input tuple;
-- * appends their types to the node's input type and to the interface's
--   declared input type;
-- * appends the same types to the output type of every top-level node with
--   id @inpNode@ in the body's hierarchies.
--
-- In each extended tuple the original value stays at component 0.
-- Only 'NoInline' nodes are expected here. Any other function raises a
-- descriptive error when the new node's function is demanded.
changeInp :: NodeId -> [((Variable, Variable -> Variable), Tuple StorableType)] -> (Node, [Hierarchy]) -> (Node, [Hierarchy])
changeInp inpNode chLs (node, hs) = (newNode, newHs)
  where
    newTyps = map snd chLs
    newVars = map (fst . fst) chLs
    newNode = Node (nodeId node)
                   (addIfcInpTypes (function node))
                   (Tup (input node : map (One . Variable) newVars))
                   (Tup (inputType node : newTyps))
                   (outputType node)
    newHs = map (addOutTypesHr inpNode newTyps) hs
    -- Only a NoInline carries an interface whose input type has to grow.
    -- Previously any other function hit an anonymous pattern-match failure.
    addIfcInpTypes (NoInline str ifc) =
      NoInline str ifc { interfaceInputType = Tup (interfaceInputType ifc : newTyps) }
    addIfcInpTypes _ =
      error "Feldspar.Compiler.Transformation.Lifting.changeInp: expected a NoInline node"
-- | Apply 'addOutTypesNode' to each top-level entry of a hierarchy.
-- Nested sub-hierarchies are passed through untouched.
addOutTypesHr :: NodeId -> [Tuple StorableType] -> Hierarchy -> Hierarchy
addOutTypesHr target extraTypes (Hierarchy entries) =
  Hierarchy [addOutTypesNode target extraTypes entry | entry <- entries]
-- | If the node has id @target@, turn its output type into a tuple. The old
-- output type is component 0 and the extra types follow it. Other nodes,
-- and the sub-hierarchies in all cases, are returned unchanged.
addOutTypesNode :: NodeId -> [Tuple StorableType] -> (Node, [Hierarchy]) -> (Node, [Hierarchy])
addOutTypesNode target extraTypes (node, subHs)
  | target == nodeId node = (node { outputType = Tup (outputType node : extraTypes) }, subHs)
  | otherwise             = (node, subHs)
-- | Seed the change list with the free variables returned by a 'NoInline'
-- interface.
--
-- An output component is lifted when it is a variable whose node does not
-- occur in @hs@ (see 'mustChange'). Each distinct variable is listed once
-- and numbered from 1, so that it becomes component @i@ of the extended
-- input tuple (component 0 is the original input).
collectChangesInterface :: Interface -> [Hierarchy] -> [((Variable, Variable -> Variable), Tuple StorableType)]
collectChangesInterface iface hs =
    map (genChange (interfaceInput iface)) (zip [1 ..] (nubBy sameVar lifted))
  where
    lifted = filter (mustChange hs . fst)
                    (tupleZipList (interfaceOutput iface, interfaceOutputType iface))
    -- A variable returned several times only needs to be passed in once.
    -- Listing it repeatedly would create redundant extra inputs.
    sameVar (Variable a, _) (Variable b, _) = a == b
    sameVar _ _ = False
-- | Build the change entry for one lifted source.
--
-- The source must be a 'Variable'; callers filter with 'mustChange' first.
-- The entry pairs that variable with a rewrite to component @index@ of input
-- node @inpId@ (see 'varChange') and with the source's type.
--
-- The index is an 'Int'. 'varChange' takes an 'Int' here, so 'NodeId' is
-- necessarily the same type and the signature is unchanged for callers.
genChange :: NodeId -> (Int, (Source, StorableType)) -> ((Variable, Variable -> Variable), Tuple StorableType)
genChange inpId (index, (src, typ)) =
  case src of
    Variable var -> ((var, varChange inpId index), One typ)
    _            -> error "Feldspar.Compiler.Transformation.Lifting.genChange: only variable sources can be lifted"
-- | Does this source have to be lifted into an explicit input?
-- True exactly for variables whose node id does not occur in @hs@.
-- Every other source yields False.
mustChange :: [Hierarchy] -> Source -> Bool
mustChange hs (Variable (nid, _)) = notInHr nid hs
mustChange _  _                   = False
-- | Shift a variable one level down by prefixing its path with component 0.
-- The original input ends up as component 0 of the extended input tuple.
inpVarsChange :: Variable -> Variable
inpVarsChange (nid, path) = (nid, 0 : path)
-- | A rewrite that ignores the variable it is given and always yields
-- @(inpId, [index])@, i.e. component @index@ of node @inpId@.
varChange :: NodeId -> Int -> Variable -> Variable
varChange inpId index = const (inpId, [index])
-- | Values that can be scanned for variables that must be lifted out of a
-- 'NoInline' body.
--
-- The first argument pairs the id of the body's input node with the body's
-- hierarchies. The second is the list of changes collected so far. The
-- result is that list with the changes found in the value appended.
class CollectChangesHr t where
  collectChangesHr
    :: (NodeId, [Hierarchy])
    -> [((Variable, Variable -> Variable), Tuple StorableType)]
    -> t
    -> [((Variable, Variable -> Variable), Tuple StorableType)]
instance CollectChangesHr Hierarchy where
  -- Scan each (node, sub-hierarchies) entry in order, threading the
  -- accumulated change list. foldl' avoids building a chain of thunks.
  collectChangesHr ctx changesList (Hierarchy nodeHsList) =
    foldl' (collectChangesHr ctx) changesList nodeHsList
instance CollectChangesHr (Node, [Hierarchy]) where
  -- Scan the node itself first, then each of its sub-hierarchies in order.
  -- foldl' keeps the accumulated change list evaluated.
  collectChangesHr ctx changesList (node, hsList) =
    foldl' (collectChangesHr ctx) (collectChangesHr ctx changesList node) hsList
instance CollectChangesHr Node where
  -- Free variables among the node's inputs are collected first. The
  -- interfaces of the node's function are scanned afterwards.
  collectChangesHr ctx changesList node =
      collectChangesHr ctx fromInputs (function node)
    where
      freeInputs = filter (mustChange (snd ctx) . fst)
                          (tupleZipList (input node, inputType node))
      fromInputs = collectChangesHr ctx changesList freeInputs
instance CollectChangesHr [(Source,StorableType)] where
  -- Append one change per lifted source. New entries are numbered after
  -- those already collected, so entry i is component i of the extended input.
  -- Variables already in the list, or repeated within this batch, are
  -- skipped; otherwise one outer variable would become several redundant
  -- inputs. Numbering stays contiguous because filtering happens before
  -- indices are assigned.
  collectChangesHr (inpId, _) changesList sourceList =
      changesList ++ map (genChange inpId) (zip [length changesList + 1 ..] fresh)
    where
      known = map (fst . fst) changesList
      fresh = nubBy sameVar (filter (not . isKnown . fst) sourceList)
      isKnown (Variable v) = v `elem` known
      isKnown _            = False
      sameVar (Variable a, _) (Variable b, _) = a == b
      sameVar _ _ = False
instance CollectChangesHr Function where
  -- NoInline, Parallel, IfThenElse and While are scanned through their
  -- interfaces (the first interface before the second). Any other function
  -- leaves the change list as it is.
  collectChangesHr ctx changesList fun =
      case fun of
        NoInline _ ifc       -> viaIfc changesList ifc
        Parallel ifc         -> viaIfc changesList ifc
        IfThenElse ifc1 ifc2 -> viaIfc (viaIfc changesList ifc1) ifc2
        While ifc1 ifc2      -> viaIfc (viaIfc changesList ifc1) ifc2
        _                    -> changesList
    where
      viaIfc :: [((Variable, Variable -> Variable), Tuple StorableType)] -> Interface -> [((Variable, Variable -> Variable), Tuple StorableType)]
      viaIfc = collectChangesHr ctx
instance CollectChangesHr Interface where
  -- Select the interface outputs that must be lifted, then delegate to the
  -- [(Source, StorableType)] instance. The Node instance does the same for
  -- node inputs, so numbering (and any deduplication) lives in one place
  -- instead of being copied here.
  collectChangesHr ctx changesList ifc =
    collectChangesHr ctx changesList
      (filter (mustChange (snd ctx) . fst)
              (tupleZipList (interfaceOutput ifc, interfaceOutputType ifc)))
-- | @notInHr nid x@ is 'True' when no node with id @nid@ occurs anywhere in
-- @x@, including nested sub-hierarchies.
class NotInHr t where
  notInHr :: NodeId -> t -> Bool
instance NotInHr [Hierarchy] where
  -- The id is absent from the list only if it is absent from every hierarchy.
  notInHr nid = all (notInHr nid)
instance NotInHr Hierarchy where
  -- Every (node, sub-hierarchies) entry must be free of the id.
  notInHr nid (Hierarchy nodeHs) = all (notInHr nid) nodeHs
instance NotInHr (Node, [Hierarchy]) where
  -- Check the node itself first; look into its sub-hierarchies only if the
  -- node does not already match.
  notInHr nid (node, subHs)
    | notInHr nid node = notInHr nid subHs
    | otherwise        = False
instance NotInHr Node where
  -- A single node is free of the id unless it carries that id.
  notInHr nid = (nid /=) . nodeId