module Edges.NodeCounting
(
  NodeCounts,
  Node,
  Amount,
  Edges,
  node,
  nodeTargets,
  targets,
)
where

import Edges.Prelude hiding (index, toList)
import Edges.Types
import Edges.Functions
import Edges.Instances ()
import qualified PrimitiveExtras.PrimArray as PrimArray
import qualified PrimitiveExtras.PrimMultiArray as PrimMultiArray
import qualified PrimitiveExtras.TVarArray as TVarArray
import qualified DeferredFolds.UnfoldM as UnfoldM
import qualified Data.Vector.Unboxed as UnboxedVector
import qualified Control.Monad.Par.IO as Par
import qualified Control.Monad.Par as Par hiding (runParIO)


{-|
Counts for a single source node,
sized by the source amount of the given edges.
-}
node :: Edges source target -> Node source -> NodeCounts source
node = nodeWithAmount . edgesSourceAmount

{-|
Counts of the given size built with 'PrimArray.oneHot',
passing the node's index and the value 1.
Presumably every other position is zero (depends on 'PrimArray.oneHot').
-}
nodeWithAmount :: Amount entity -> Node entity -> NodeCounts entity
nodeWithAmount (Amount total) (Node position) =
  NodeCounts $ PrimArray.oneHot total position 1

{-|
Counts of the targets of a single source node:
the target indices stored for that source are fed
into 'PrimArray.indexCountsFold' sized by the target amount.
-}
nodeTargets :: Edges source target -> Node source -> NodeCounts target
nodeTargets (Edges targetAmount edgesPma) (Node sourceIndex) =
  NodeCounts (UnfoldM.fold countingFold targetIndices)
  where
    targetIndices = fmap fromIntegral (PrimMultiArray.toUnfoldAtM edgesPma sourceIndex)
    countingFold = PrimArray.indexCountsFold targetAmount

{-|
Count the occurrences of targets based on the occurrences of sources.
Same as 'targetsWithMinSourceAmount' with the minimum source count of 1.
-}
targets :: Edges source target -> NodeCounts source -> NodeCounts target
targets edges sourceCounts = targetsWithMinSourceAmount 1 edges sourceCounts

{-|
Count the occurrences of targets based on the occurrences of sources.
This function can be used to reduce the amount of computation
by excluding the nodes which don't make much difference.

Sources whose count is below the first argument are skipped;
every remaining source adds its own count to each of its target entries.
-}
targetsWithMinSourceAmount :: Word64 -> Edges source target -> NodeCounts source -> NodeCounts target
targetsWithMinSourceAmount minSourceCount (Edges targetAmount edgesPma) (NodeCounts sourceCountsPa) =
  -- The mutable array is allocated, filled and frozen inside this single action
  -- and is never touched after freezing, so the result depends only on the arguments.
  unsafePerformIO $ do
    targetCountsMpa <- newPrimArray targetAmount
    -- 'newPrimArray' leaves the memory uninitialized, so every counter has to be
    -- zeroed before the accumulation below reads it.
    forMInAscendingRange_ 0 targetAmount $ \ targetIndex ->
      writePrimArray targetCountsMpa targetIndex 0
    forMInAscendingRange_ 0 (sizeofPrimArray sourceCountsPa) $ \ sourceIndex ->
      let sourceCount = indexPrimArray sourceCountsPa sourceIndex
      in
        if sourceCount >= minSourceCount
          then
            UnfoldM.forM_ (fmap fromIntegral $ PrimMultiArray.toUnfoldAtM edgesPma sourceIndex) $ \ targetIndex -> do
              targetCount <- readPrimArray targetCountsMpa targetIndex
              writePrimArray targetCountsMpa targetIndex (targetCount + sourceCount)
          else return ()
    targetCountsPa <- unsafeFreezePrimArray targetCountsMpa
    return (NodeCounts targetCountsPa)

{-|
Count the occurrences of targets based on the occurrences of sources.
Utilizes parallelism.

Sources with a zero count are skipped;
every other source adds its count to each of its target entries
through a table of per-target variables, which is frozen into the result at the end.
-}
targetsParallelly :: Edges source target -> NodeCounts source -> NodeCounts target
targetsParallelly (Edges targetAmount edgesPma) (NodeCounts sourceCountsPa) =
  unsafePerformIO $ Par.runParIO $ do
    -- Every counter starts at 0.
    targetCountVarTable <- liftIO (TVarArray.new 0 targetAmount)
    -- The range is inclusive, hence the 'pred' on the size.
    Par.parFor (Par.InclusiveRange 0 (pred (sizeofPrimArray sourceCountsPa))) $ \ sourceIndex ->
      case indexPrimArray sourceCountsPa sourceIndex of
        0 -> return ()
        sourceCount ->
          liftIO $
            UnfoldM.forM_ (PrimMultiArray.toUnfoldAtM edgesPma sourceIndex) $ \ targetIndex ->
              TVarArray.modifyAt targetCountVarTable (fromIntegral targetIndex) (+ sourceCount)
    targetCountsPa <- liftIO (TVarArray.freezeAsPrimArray targetCountVarTable)
    return (NodeCounts targetCountsPa)