module Spark.Core.Internal.CachingUntyped(
cachingType,
autocacheGen
) where
import Control.Monad.Except
import Spark.Core.Internal.Caching
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.PathsUntyped()
import Spark.Core.Internal.DAGStructures
import Spark.Core.StructuresInternal
-- | Classifies an untyped node for the caching analysis.
--
-- * Local operators, local literals, aggregations, reductions and
--   pointers yield 'Stop'.
-- * Structured transforms, distributed literals, broadcast joins and any
--   distributed operator without a special name yield 'Through'.
-- * A distributed operator named 'opnameCache', 'opnameUnpersist' or
--   'opnameAutocache' yields 'CacheOp', 'UncacheOp' or 'AutocacheOp',
--   built from this node's vertex id.
--
-- For 'UncacheOp', the vertex id of the node's parent is also recorded.
-- If the node does not have exactly one parent, an error is raised with
-- 'throwError' in 'CacheTry'.
cachingType :: UntypedNode -> CacheTry NodeCachingType
cachingType n = case nodeOp n of
  -- Alternatives yielding 'Through' (or a special caching operation).
  NodeStructuredTransform _ -> passThrough
  NodeDistributedLit _ _ -> passThrough
  NodeBroadcastJoin -> passThrough
  NodeDistributedOp so -> byOpName (soName so)
  -- Alternatives yielding 'Stop'.
  NodeLocalOp _ -> stopHere
  NodeLocalLit _ _ -> stopHere
  NodeAggregatorReduction _ -> stopHere
  NodeAggregatorLocalReduction _ -> stopHere
  NodeOpaqueAggregator _ -> stopHere
  NodeGroupedReduction _ -> stopHere
  NodeReduction _ -> stopHere
  NodePointer _ -> stopHere
  where
    selfId = vertexToId n

    stopHere, passThrough :: CacheTry NodeCachingType
    stopHere = pure Stop
    passThrough = pure Through

    -- Dispatch on the name of a distributed operator. Guards are tried
    -- top to bottom; unrecognised names fall back to 'Through'.
    byOpName name
      | name == opnameCache = pure (CacheOp selfId)
      | name == opnameUnpersist = case nodeParents n of
          [parent] -> pure (UncacheOp selfId (vertexToId parent))
          _ -> throwError "Node is not valid uncache node"
      | name == opnameAutocache = pure (AutocacheOp selfId)
      | otherwise = passThrough
-- | 'AutocacheGen' for untyped nodes.
--
-- Both generators take an existing vertex and apply a node transformation
-- to its payload: 'uncache' for 'deriveUncache' and 'identity' for
-- 'deriveIdentity'. The result is wrapped in a new vertex. The incoming
-- vertex id is discarded; the new id is derived from the transformed
-- node's 'nodeId'.
autocacheGen :: AutocacheGen UntypedNode
autocacheGen = AutocacheGen
  { deriveIdentity = rebuildWith identity
  , deriveUncache = rebuildWith uncache
  }
  where
    -- Transform the payload, then key the vertex by the new node's id.
    rebuildWith step (Vertex _ payload) =
      let payload' = step payload
          freshId = VertexId (unNodeId (nodeId payload'))
      in Vertex freshId payload'