module Language.Paraiso.Optimization.DependencyAnalysis (
writeGrouping
) where
import qualified Data.Graph.Inductive as FGL
import qualified Data.Set as Set
import qualified Data.Vector as V
import qualified Language.Paraiso.Annotation as Anot
import qualified Language.Paraiso.Annotation.Allocation as Alloc
import qualified Language.Paraiso.Annotation.Boundary as Boundary
import qualified Language.Paraiso.Annotation.Dependency as Depend
import Language.Paraiso.OM
import qualified Language.Paraiso.OM.DynValue as DVal
import Language.Paraiso.OM.Graph
import qualified Language.Paraiso.OM.Realm as Realm
import qualified Language.Paraiso.Optimization.Graph as Opt
-- | Apply 'dependencyAnalysis' to the machine (through 'Opt.gmap') and then
--   rewrite the per-kernel write-group numbers so that they no longer clash
--   between kernels: each 'Depend.KernelWriteGroup' annotation becomes a
--   'Depend.OMWriteGroup' whose id is shifted by the number of ids reserved
--   by all preceding kernels.
writeGrouping :: Opt.Ready v g => OM v g Anot.Annotation -> OM v g Anot.Annotation
writeGrouping om0 = analyzed { kernels = V.imap globalize analyzedKernels }
  where
    analyzed        = Opt.gmap dependencyAnalysis om0
    analyzedKernels = kernels analyzed

    -- Number of group ids each kernel reserves.  The literal 1 seeds
    -- 'maximum' so it never sees an empty list; as a consequence every
    -- kernel reserves at least two ids.
    reserved = V.map countGroups (V.map dataflow analyzedKernels)

    countGroups graph =
      1 + maximum
            (1 : [ Depend.getKernelGroupID kwg
                 | (_, nd) <- FGL.labNodes graph
                 , kwg     <- writeGroupsOf (getA nd)
                 ])

    -- The kernel-local write group stored in an annotation, if there is one.
    writeGroupsOf :: Anot.Annotation -> [Depend.KernelWriteGroup]
    writeGroupsOf = maybe [] (: []) . Anot.toMaybe

    -- Replace the dataflow graph of the i-th kernel by a copy whose group
    -- ids are offset by the ids reserved by kernels 0 .. i-1.
    globalize i kern =
      kern { dataflow = nmap (Anot.map toGlobal) (dataflow kern) }
      where
        offset   = V.sum (V.take i reserved)
        toGlobal = Depend.OMWriteGroup . (offset +) . Depend.getKernelGroupID
-- | Annotate every node of one dataflow graph with dependency information.
--
--   A node is called /strict/ here when its allocation is 'Alloc.Manifest'
--   or 'Alloc.Existing' (see @isStrict@).  Strict nodes get a dense
--   secondary index, the /sidx/; all Boolean rows below are indexed by sidx.
--
--   Each node receives
--
--   * 'Depend.Direct'   – strict nodes reached by walking predecessors and
--                         stopping at the first strict node on every path;
--   * 'Depend.Indirect' – the same walk, but continuing through strict nodes;
--   * 'Depend.Calc'     – the set of predecessor nodes visited by the
--                         'Depend.Direct' walk, the strict end points included;
--   * 'Depend.KernelWriteGroup' – only for 'Alloc.Manifest' nodes.
--
--   Raises (via 'error') when a node lacks an 'Alloc.Allocation' or a
--   'Boundary.Valid' annotation.
--
--   NOTE(review): vectors built from 'FGL.labNodes' are indexed by node id,
--   so this assumes node ids are exactly @0 .. noNodes-1@ and that
--   'FGL.labNodes' lists them in ascending order — confirm against the graph
--   builder.
dependencyAnalysis :: Opt.Ready v g => Opt.OptimizationOf v g
dependencyAnalysis graph = imap update graph
  where
    -- Setters are composed right to left: the indirect list is stored first,
    -- the write group (Manifest nodes only) last.
    update :: FGL.Node -> Anot.Annotation -> Anot.Annotation
    update i =
      setGroup i
        . Anot.set (calcList i)
        . Anot.set (dependencyList i)
        . Anot.set (indirectList i)

    -- Turn a Boolean row (indexed by sidx) into the list of graph nodes whose
    -- entry is True.
    flaggedNodes :: V.Vector Bool -> [FGL.Node]
    flaggedNodes row =
      V.toList $
      V.map fst $
      V.filter snd $
      V.imap (\sidx dep -> (sidxToIdx V.! sidx, dep)) row

    dependencyList idx = Depend.Direct   $ flaggedNodes $ dependMatrixWrite   V.! idx

    indirectList   idx = Depend.Indirect $ flaggedNodes $ indirectMatrixWrite V.! idx

    calcList       idx = Depend.Calc $ calcMatrixWrite V.! idx

    -- Only Manifest nodes carry a write group; other nodes are left untouched.
    setGroup idx =
      case idxToAlloc V.! idx of
        Alloc.Manifest -> Anot.set $ Depend.KernelWriteGroup (kernelGroup V.! idx)
        _              -> id

    sidxSize = V.length sidxToIdx
    idxSize  = FGL.noNodes graph

    -- Write group of each node.  The entry for a node refers to the entries
    -- of smaller Manifest nodes of this same vector, which works because the
    -- elements are evaluated lazily.  Only entries of Manifest nodes are ever
    -- demanded (see 'setGroup').
    kernelGroup :: V.Vector Int
    kernelGroup = V.generate idxSize inner
      where
        inner idx
          | null pres = 0  -- the first Manifest node opens group 0
          | otherwise =
              case coGroups of
                -- join the lowest-numbered group all of whose members can
                -- coexist with this node ...
                grp : _ -> grp
                -- ... or open a fresh group after the largest one so far
                []      -> 1 + maximum (map (kernelGroup V.!) pres)
          where
            -- Manifest nodes that come before idx
            pres           = takeWhile (< idx) manifestNodes
            -- distinct groups already handed out, in ascending order
            existingGroups = Set.toList $ Set.fromList $ map (kernelGroup V.!) pres
            groupMember grp = filter ((== grp) . (kernelGroup V.!)) pres
            coGroups       = filter (all (coexist idx) . groupMember) existingGroups

    -- Two Manifest nodes may share a write group when the later one does not
    -- (indirectly) depend on the earlier one and both have the same realm
    -- and the same valid region.
    coexist :: FGL.Node -> FGL.Node -> Bool
    coexist idx jdx
      | idxToAlloc V.! idx /= Alloc.Manifest || idxToAlloc V.! jdx /= Alloc.Manifest
                  = error "coexistence not defined for non-Manifest nodes"
      | idx == jdx = True
      | idx <  jdx = coexist jdx idx  -- normalise so that idx is the later node
      | otherwise  = not dependent' && sameShape
      where
        dependent' = indirectMatrixWrite V.! idx V.! (idxToSidx V.! jdx)
        sameShape  =
          idxToRealm V.! idx == idxToRealm V.! jdx &&
          idxToValid V.! idx == idxToValid V.! jdx

    -- Allocation annotation of every node.
    idxToAlloc :: V.Vector Alloc.Allocation
    idxToAlloc = V.fromList $
      map (\(_, nd) -> f' $ Anot.toMaybe $ getA nd) $
      FGL.labNodes graph
      where
        f' (Just x) = x
        f' Nothing  = error "writeGrouping must be done after decideAllocation"

    idxToValid = idxToValid' graph

    -- Valid-region annotation of every node.  The argument is ignored at the
    -- value level (the enclosing 'graph' is read instead).
    -- NOTE(review): it presumably exists only so that the signature can tie
    -- the result's @g@ to the graph's @g@ — confirm.
    idxToValid' :: (Opt.Ready v g) => (Graph v g Anot.Annotation) -> V.Vector (Boundary.Valid g)
    idxToValid' _ = V.fromList $
      map (\(_, nd) -> f' $ Anot.toMaybe $ getA nd) $
      FGL.labNodes graph
      where
        f' (Just x) = x
        -- A missing Valid annotation means the boundary pass has not run;
        -- the message names that prerequisite.
        f' Nothing  = error "writeGrouping must be done after boundaryAnalysis"

    -- Realm of every node; defined only for value nodes.
    idxToRealm :: V.Vector Realm.Realm
    idxToRealm = V.generate idxSize inner
      where
        inner idx = case FGL.lab graph idx of
          Just (NValue (DVal.DynValue r _) _) -> r
          Just _                              -> error "realm required for non-Value node"
          Nothing                             -> error "indexing mismatch"

    -- Nodes annotated as Manifest, in 'FGL.labNodes' order.
    manifestNodes :: [FGL.Node]
    manifestNodes =
      [ idx
      | (idx, nd) <- FGL.labNodes graph
      , Anot.toMaybe (getA nd) == Just Alloc.Manifest
      ]

    -- sidx -> node id: the ids of all strict nodes, ascending.
    sidxToIdx :: V.Vector Int
    sidxToIdx = V.findIndices id isStrict

    -- node id -> sidx.  Entries of non-strict nodes raise an error when forced.
    idxToSidx :: V.Vector Int
    idxToSidx = V.generate idxSize inner
      where
        inner idx = case V.elemIndex idx sidxToIdx of
          Just sidx -> sidx
          Nothing   -> error $ show idx ++ " is not a Strict node"

    isStrict :: V.Vector Bool
    isStrict = V.generate idxSize $ \idx ->
      case idxToAlloc V.! idx of
        Alloc.Manifest -> True
        Alloc.Existing -> True
        Alloc.Delayed  -> False

    -- The "write" row of a node merges the "read" rows of its predecessors;
    -- the "read" row of a strict node is just the node itself, while a
    -- non-strict node passes its own write row through.  The two vectors are
    -- mutually recursive and rely on lazy element evaluation.
    dependMatrixWrite :: V.Vector (V.Vector Bool)
    dependMatrixWrite = V.generate idxSize dependRowWrite

    dependMatrixRead :: V.Vector (V.Vector Bool)
    dependMatrixRead = V.generate idxSize dependRowRead

    dependRowWrite idx =
      foldl mergeRow allFalseRow $ map (dependMatrixRead V.!) $ FGL.pre graph idx

    dependRowRead idx
      | isStrict V.! idx = V.map (== idx) sidxToIdx
      | otherwise        = dependMatrixWrite V.! idx

    -- Same scheme, except that a strict node also forwards its own write
    -- row, so dependencies propagate through strict nodes.
    indirectMatrixWrite :: V.Vector (V.Vector Bool)
    indirectMatrixWrite = V.generate idxSize indirectRowWrite

    indirectMatrixRead :: V.Vector (V.Vector Bool)
    indirectMatrixRead = V.generate idxSize indirectRowRead

    indirectRowWrite idx =
      foldl mergeRow allFalseRow $ map (indirectMatrixRead V.!) $ FGL.pre graph idx

    indirectRowRead idx
      | isStrict V.! idx = V.map (== idx) sidxToIdx `mergeRow` (indirectMatrixWrite V.! idx)
      | otherwise        = indirectMatrixWrite V.! idx

    allFalseRow = V.replicate sidxSize False

    -- Element-wise OR of two rows; both must have exactly sidxSize entries.
    mergeRow va vb
      | V.length va /= sidxSize = error "wrong size contamination in dependMatrix"
      | V.length vb /= sidxSize = error "wrong size contamination in dependMatrix"
      | otherwise               = V.zipWith (||) va vb

    -- Node sets, built like the direct-dependency rows: a strict node
    -- contributes only itself, a non-strict node contributes itself plus
    -- everything in its own write set.
    calcMatrixRead :: V.Vector (Set.Set FGL.Node)
    calcMatrixRead = V.generate idxSize calcRowRead

    calcMatrixWrite :: V.Vector (Set.Set FGL.Node)
    calcMatrixWrite = V.generate idxSize calcRowWrite

    calcRowWrite idx = Set.unions $ map (calcMatrixRead V.!) $ FGL.pre graph idx

    calcRowRead idx
      | isStrict V.! idx = Set.singleton idx
      | otherwise        = Set.insert idx $ calcMatrixWrite V.! idx