module Math.SetCover.Exact.Knead.Saturated (
partitionsIO, searchIO, stepIO,
partitions,
State(..), initStateIO, updateStateIO,
) where
import qualified Math.SetCover.Exact as ESC
import Math.SetCover.Exact.Knead.Vector (Block)
import Math.SetCover.Exact.Knead.Symbolic
(SetId, SetDim, BlockId, BlockDim, blocksFromSets,
nullSet, disjoint, disjointRows, differenceWithRow,
findIndices, collectRows)
import qualified Control.Monad.HT as Monad
import Control.Monad.HT ((<=<))
import Control.Monad (foldM)
import Control.Applicative (liftA2, liftA3, pure, (<$>), (<*>))
import qualified Data.Array.Knead.Parameterized.Render as Render
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Symbolic as Symb
import qualified Data.Array.Knead.Simple.Slice as Slice
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression.Vector as ExprVec
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, (/=*), (<*), (.&.*), )
import qualified Data.Array.Comfort.Boxed as Array
import Data.Array.Comfort.Boxed (Array)
import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Vector as MultiValueVec
import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Vector.Memory ()
import LLVM.Extra.Multi.Value (atom)
import qualified LLVM.Core as LLVM
import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal ((:+:))
import qualified System.IO.Lazy as LazyIO
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.List.Match as Match
import qualified Data.Set as Set
import qualified Data.Word as Word
import qualified Data.Int as Int
import qualified Data.Bool8 as Bool8
import qualified Data.Bits as Bits
import Data.Functor.Compose (Compose(Compose,getCompose))
import Data.Set (Set)
import Prelude2010
import Prelude ()
-- | Type-level number of lanes in one 'Counters' vector.
type NumCounters = TypeNum.D16
-- | A single occurrence counter (one lane of 'Counters').
type Counter = Word.Word8
-- | A vector of 'NumCounters' counters.
type Counters = LLVM.Vector NumCounters Counter
-- | Element type obtained when a 'Block' is indexed by a 'CounterId'
-- (see 'extractBlock').
type Subblock = Word.Word8
-- | Vector of eight 16-bit words; 'toBlock16' bitcasts a 'Block' to this type.
type Block16 = LLVM.Vector TypeNum.D8 Word.Word16
-- | Number of bits in a single 'Counter'.
--
-- Uses 'Bits.finiteBitSize' instead of the deprecated 'Bits.bitSize';
-- both return the same value for fixed-width types such as 'Counter'.
bitSize :: Int
bitSize = Bits.finiteBitSize (0 :: Counter)
-- | Value-level counterpart of the type-level number 'NumCounters'.
numCounters :: Integer
numCounters = TypeNum.integerFromSingleton witness
  where
    witness :: TypeNum.Singleton NumCounters
    witness = TypeNum.singleton
-- | Index of a lane within a 'Counters' vector (see 'extract').
type CounterId = Int.Int16
-- | Index of a bit; converted to a shift amount by 'bitPos'.
type BitId = Int.Int8
-- | Zero-based shape indexed by 'CounterId'.
type CounterDim = Shape.ZeroBased CounterId
-- | Zero-based shape indexed by 'BitId'.
type BitDim = Shape.ZeroBased BitId
-- | Search state of the solver.
--
-- The fields are deliberately left without strictness annotations, exactly
-- as in the original declaration; the state is built inside
-- 'LazyIO.interleave' in 'updateStateIO'.
data State label = State
  { -- | Labels of the sets that may still be chosen, together with a
    -- (set, block) matrix holding the blocks of those sets.
    availableSubsets ::
      (Array SetDim label, Phys.Array (SetDim, BlockDim) Block)
  , -- | Blocks describing the elements that are still free; the search
    -- stops once 'nullSet' holds for this array (see 'searchIO').
    freeElements :: Phys.Array BlockDim Block
  , -- | Labels of the sets chosen so far, most recently chosen first.
    usedSubsets :: [label]
  }
-- | Build the initial 'State' from a list of labelled sets.
--
-- Empty sets are discarded first.  The remaining sets are converted to
-- blocks with 'blocksFromSets'; every row is cut to the length of the list
-- of free blocks before being written to the (set, block) matrix.
-- No set is marked as used in the initial state.
initStateIO :: (Ord a) => [ESC.Assign label (Set a)] -> IO (State label)
initStateIO assigns = do
  freeArr <- Phys.vectorFromList freeBlocks
  availArr <-
    Phys.fromList
      (setShape, Phys.shape freeArr)
      (concatMap (Match.take freeBlocks) rows)
  return
    State
      { availableSubsets =
          (Array.fromList setShape (map ESC.label nonEmpty), availArr)
      , freeElements = freeArr
      , usedSubsets = []
      }
  where
    -- only non-empty sets take part in the search
    nonEmpty = filter (not . Set.null . ESC.labeledSet) assigns
    (rows, freeBlocks) = blocksFromSets (map ESC.labeledSet nonEmpty)
    setShape = Shape.ZeroBased (fromIntegral (length nonEmpty))
-- | Constant 'Counters' expression built from a single 'Counter' value
-- via 'Expr.fromInteger''.
repVec :: Counter -> Exp Counters
repVec c = Expr.fromInteger' (toInteger c)
-- | Lane-wise unsigned saturating addition of two counter vectors.
--
-- This is the portable fallback for the @paddusb@ instruction used in
-- 'incSatX86'.  The previous formulation only clamped lanes where @y@ was
-- already 'maxBound', so it wrapped around for summands larger than one and
-- thus disagreed with @paddusb@.  Here overflow is detected generally:
-- an unsigned sum wrapped around exactly if it is smaller than one of its
-- operands.  For the 0\/1 summands produced by 'pickBitsX86' the result is
-- the same as before.
incSat :: Exp Counters -> Exp Counters -> Exp Counters
incSat x y =
  let total = Expr.add x y
      -- 'Counter' is an unsigned type, so the comparison is expected to be
      -- unsigned, as already relied upon by 'argMinVec'.
      wrapped = ExprVec.cmp LLVM.CmpLT total y
  in  ExprVec.select wrapped (repVec maxBound) total
-- | 'incSat' lifted to plain LLVM values: both arguments are wrapped into
-- 'MultiValue.T' values, 'incSat' is run on them and the result is
-- unwrapped again.  Serves as the alternative code path in 'incSatX86'.
incSatGeneric ::
  LLVM.Value Counters -> LLVM.Value Counters ->
  LLVM.CodeGenFunction r (LLVM.Value Counters)
incSatGeneric x y = do
  result <-
    Expr.unliftM2 incSat
      (MultiValue.Cons (Compose x))
      (MultiValue.Cons (Compose y))
  case result of
    MultiValue.Cons z -> return (getCompose z)
-- | Saturating addition of counter vectors.
--
-- 'Ext.run' is handed 'incSatGeneric' as the alternative and the
-- 'X86.paddusb128' extension as the preferred implementation.
incSatX86 :: Exp Counters -> Exp Counters -> Exp Counters
incSatX86 =
  Expr.liftM2
    (MultiValue.liftM2
       (\xc yc ->
          let x = getCompose xc
              y = getCompose yc
          in  fmap Compose
                (Ext.run
                   (incSatGeneric x y)
                   (X86.paddusb128 <*> pure x <*> pure y))))
-- | Describe a fold over the outer (set) dimension that adds up all rows
-- with 'incSatX86', starting from an array of zeros that has the shape of
-- the inner dimension.
sumRows ::
  Symb.Array (SetDim, blockDim) Counters ->
  Render.FoldOuterL SetDim blockDim Counters Counters
sumRows xs = Render.FoldOuterL step zeros xs
  where
    -- accumulator comes first in the fold, 'incSatX86' takes it second
    step acc row = incSatX86 row acc
    zeros = Symb.fill (Expr.snd (Symb.shape xs)) Expr.zero
-- | Slice that appends a trailing bit dimension of extent 'bitSize'.
extrudeBits :: Slice.T sh (sh, BitDim)
extrudeBits =
  Slice.extrudeSnd
    (Expr.compose
       (Shape.ZeroBased (Expr.fromInteger' (toInteger bitSize))))
-- | Slice that appends a trailing counter dimension of extent 'numCounters'.
extrudeCounters :: Slice.T sh (sh, CounterDim)
extrudeCounters =
  Slice.extrudeSnd
    (Expr.compose (Shape.ZeroBased (Expr.fromInteger' numCounters)))
-- | Reinterpret a 'Block' as a 'Counters' vector by means of an LLVM bitcast.
toCounters :: Exp Block -> Exp Counters
toCounters block =
  Expr.liftM (MultiValue.liftM (\v -> Compose <$> LLVM.bitcast v)) block
-- | Shift every 'Counter' lane of the block right by @k@ and keep only the
-- lowest bit of each lane.  Not referenced elsewhere in this module;
-- 'pickBitsX86' is the variant in use.
_pickBits :: Exp BitId -> Exp Block -> Exp Counters
_pickBits k block = lowestBit .&.* shifted
  where
    lowestBit = repVec 1
    shifted = Expr.shr (toCounters block) (ExprVec.replicate (bitPos k))
-- | Bit position widened to a 16-bit word ('bitPos' followed by 'LLVM.ext').
word16 :: Exp BitId -> Exp Word.Word16
word16 k = Expr.liftM (MultiValue.liftM LLVM.ext) (bitPos k)
-- | Reinterpret a 'Block' as a 'Block16' by means of an LLVM bitcast.
toBlock16 :: Exp Block -> Exp Block16
toBlock16 block =
  Expr.liftM (MultiValue.liftM (\v -> Compose <$> LLVM.bitcast v)) block
-- | Reinterpret a 'Block16' as a 'Counters' vector by means of an LLVM
-- bitcast.
fromBlock16 :: Exp Block16 -> Exp Counters
fromBlock16 words16 =
  Expr.liftM
    (MultiValue.liftM (\v -> fmap Compose (LLVM.bitcast (getCompose v))))
    words16
-- | Variant of '_pickBits' that performs the shift on 16-bit lanes:
-- the block is viewed as 'Block16', every lane is shifted right by @k@,
-- the result is viewed as 'Counters' and masked with 1 in every lane.
pickBitsX86 :: Exp BitId -> Exp Block -> Exp Counters
pickBitsX86 k block = lowestBit .&.* fromBlock16 shifted16
  where
    lowestBit = repVec 1
    shifted16 = Expr.shr (toBlock16 block) (ExprVec.replicate (word16 k))
-- | Add a bit dimension to the block dimension and fill every cell with
-- 'pickBitsX86' applied to the bit index of that cell and the block.
uninterleaveBits ::
  Symb.Array (SetDim, BlockDim) Block ->
  Symb.Array (SetDim, (BlockDim, BitDim)) Counters
uninterleaveBits blocks =
  Symb.mapWithIndex
    (\ix block -> pickBitsX86 (Expr.snd (Expr.snd ix)) block)
    withBitDim
  where
    withBitDim = Slice.apply (Slice.second extrudeBits) blocks
-- | Prepare an action that, given a set id, keeps only the rows selected by
-- 'disjointRows' for that id.  Labels and block rows are filtered with the
-- same index list, so they stay aligned.
filterDisjointRows ::
  IO (SetId ->
      (Array SetDim label, Phys.Array (SetDim,BlockDim) Block) ->
      IO (Array SetDim label, Phys.Array (SetDim,BlockDim) Block))
filterDisjointRows = do
  selectDisjoint <- Render.run (\k rows -> findIndices (disjointRows k rows))
  gather <- Render.run collectRows
  return $ \k0 (labels, sets) -> do
    perm <- selectDisjoint k0 sets
    keptIds <- Phys.toList perm
    keptSets <- gather perm sets
    return
      ( Array.fromList (Phys.shape perm) (map (labels Array.!) keptIds)
      , keptSets
      )
-- | Prepare the state transition for choosing set @k@:
-- the available sets are filtered with 'filterDisjointRows', the free
-- elements are updated with 'differenceWithRow', and the label of @k@ is
-- prepended to 'usedSubsets'.  The transition runs under
-- 'LazyIO.interleave'.
updateStateIO :: IO (SetId -> State label -> LazyIO.T (State label))
updateStateIO = do
  keepDisjoint <- filterDisjointRows
  subtractRow <- Render.run differenceWithRow
  return $ \k s ->
    let avail = availableSubsets s
    in  LazyIO.interleave $ do
          avail' <- keepDisjoint k avail
          free' <- subtractRow (freeElements s) k (snd avail)
          return
            (State avail' free' ((fst avail Array.! k) : usedSubsets s))
-- | Re-wrap the payload of a 'MultiValue.T' vector as a 'MultiVector.T'.
mvvec :: MultiValue.T (LLVM.Vector n a) -> MultiVector.T n a
mvvec v =
  case v of
    MultiValue.Cons x -> MultiVector.Cons x
-- | Extract the lane with the given index from a vector expression.
-- The index is zero-extended before it is passed to 'MultiVector.extract'.
extract ::
  (TypeNum.Positive n, MultiVector.C a) =>
  Exp CounterId -> Exp (LLVM.Vector n a) -> Exp a
extract =
  Expr.liftM2
    (\(MultiValue.Cons k) v -> do
       lane <- LLVM.zext k
       MultiVector.extract lane (mvvec v))
-- | Extract the 'Subblock' with the given index from a 'Block' using
-- 'LLVM.extractelement'; the index is zero-extended first.
extractBlock :: Exp CounterId -> Exp Block -> Exp Subblock
extractBlock =
  Expr.liftM2
    (\(MultiValue.Cons k) (MultiValue.Cons v) -> do
       lane <- LLVM.zext k
       fmap MultiValue.Cons (LLVM.extractelement v lane))
-- | Turn an array of counter vectors into an array of single counters by
-- adding a counter dimension next to the block dimension and extracting
-- the lane addressed by that new index.
flattenCounters ::
  Symb.Array (BlockDim, BitDim) Counters ->
  Symb.Array ((BlockDim,CounterDim), BitDim) Counter
flattenCounters counts =
  Symb.mapWithIndex
    (\ix lanes -> extract (Expr.snd (Expr.fst ix)) lanes)
    withCounterDim
  where
    withCounterDim = Slice.apply (Slice.first extrudeCounters) counts
-- | Reinterpret a (signed) bit index as an unsigned 'Subblock' value by
-- means of an LLVM bitcast, e.g. for use as a shift amount.
bitPos :: Exp BitId -> Exp Subblock
bitPos k = Expr.liftM (MultiValue.liftM LLVM.bitcast) k
-- | 'Subblock' with only the bit at the given position set (@1 << k@).
singleBit :: Exp BitId -> Exp Subblock
singleBit k = Expr.shl 1 (bitPos k)
-- | Of two (tag, key) pairs return the one with the smaller key;
-- the second pair wins when the keys are equal.
argMin ::
  (MultiValue.Select x, MultiValue.Select y, MultiValue.Comparison y) =>
  Exp (x,y) -> Exp (x,y) -> Exp (x,y)
argMin xy0 xy1 = Expr.select firstIsSmaller xy0 xy1
  where
    firstIsSmaller = Expr.snd xy0 <* Expr.snd xy1
-- | Index of a minimal counter: every element is paired with its index,
-- the pairs are folded with 'argMin' and the index part is returned.
argMinimum ::
  (Shape.C sh, Shape.Index sh ~ ix, MultiValue.Select ix) =>
  Symb.Array sh Counter -> Exp ix
argMinimum counts =
  Expr.fst (Symb.fold1All argMin (Symb.mapWithIndex Expr.zip counts))
-- | Unmasked minimum search: flattens the counters, locates a minimal one
-- with 'argMinimum' and converts the bit index of the result into a
-- single-bit mask.  Not referenced elsewhere in this module.
_keepMinimum ::
  IO (Phys.Array (BlockDim, BitDim) Counters ->
      IO ((BlockId,CounterId),Counter))
_keepMinimum =
  Render.run
    (\counts -> Expr.mapSnd singleBit (argMinimum (flattenCounters counts)))
-- | 'argMin' for candidates that carry a validity flag.
--
-- * both valid: flag is true, payload is the 'argMin' of both payloads
-- * only one valid: that candidate is returned unchanged
-- * none valid: the first candidate is returned unchanged
argMinMasked ::
  (MultiValue.Select x, MultiValue.Select y, MultiValue.Comparison y) =>
  Exp (Bool, (x,y)) -> Exp (Bool, (x,y)) -> Exp (Bool, (x,y))
argMinMasked xy0 xy1 =
  Expr.select valid1 (Expr.select valid0 bothValid xy1) xy0
  where
    valid0 = Expr.fst xy0
    valid1 = Expr.fst xy1
    bothValid = Expr.zip Expr.true (argMin (Expr.snd xy0) (Expr.snd xy1))
-- | Test whether bit @j@ of subblock @k@ of the block is set.
testBlockBit :: Exp CounterId -> Exp BitId -> Exp Block -> Exp Bool
testBlockBit k j block = (shifted .&.* 1) /=* 0
  where
    shifted = Expr.shr (extractBlock k block) (bitPos j)
-- | Expand an array of blocks into an array of single bits, indexed by
-- block, subblock ('CounterDim') and bit position.
flattenBlockBits ::
  Symb.Array BlockDim Block ->
  Symb.Array ((BlockDim,CounterDim), BitDim) Bool
flattenBlockBits blocks =
  Symb.mapWithIndex
    (Expr.modify2 ((atom, atom), atom) atom
       (\((_blockIx, counterIx), bitIx) block ->
          testBlockBit counterIx bitIx block))
    (Slice.apply (Slice.compose extrudeCounters extrudeBits) blocks)
-- | Index of a minimal counter among the positions whose bit is set in
-- the array of free blocks.  Every counter is paired with its index and
-- with the corresponding bit of @free@ as validity flag; the candidates
-- are then folded with 'argMinMasked'.
argMinimumMasked ::
  Symb.Array BlockDim Block ->
  Symb.Array ((BlockDim,CounterDim), BitDim) Counter ->
  Exp ((BlockId,CounterId),BitId)
argMinimumMasked free counts =
  Expr.fst (Expr.snd (Symb.fold1All argMinMasked candidates))
  where
    indexed = Symb.mapWithIndex Expr.zip counts
    candidates = Symb.zip (flattenBlockBits free) indexed
-- | Scalar variant of 'keepMinimumMaskedVec' built from 'flattenCounters'
-- and 'argMinimumMasked'.  Not referenced elsewhere in this module.
_keepMinimumMasked ::
  IO (Phys.Array BlockDim Block ->
      Phys.Array (BlockDim,BitDim) Counters ->
      IO ((BlockId,CounterId),Counter))
_keepMinimumMasked =
  Render.run
    (\free counts ->
       Expr.mapSnd singleBit (argMinimumMasked free (flattenCounters counts)))
-- | Lane-wise 'argMin' on vectors of (tag, key) pairs: a lane takes the
-- pair from the first vector where its key is strictly smaller, otherwise
-- the pair from the second vector.
argMinVec ::
  (TypeNum.Positive n,
   MultiVector.Select x, MultiVector.Select y, MultiVector.Comparison y) =>
  Exp (LLVM.Vector n (x,y)) -> Exp (LLVM.Vector n (x,y)) ->
  Exp (LLVM.Vector n (x,y))
argMinVec xy0 xy1 = ExprVec.select firstIsSmaller xy0 xy1
  where
    firstIsSmaller =
      ExprVec.cmp LLVM.CmpLT (ExprVec.snd xy0) (ExprVec.snd xy1)
-- | Lane-wise counterpart of 'argMinMasked': per lane, two valid
-- candidates are combined with 'argMinVec', a single valid candidate is
-- kept as it is, and without a valid candidate the lane of the first
-- vector is kept.
argMinMaskedVec ::
  (TypeNum.Positive n,
   MultiVector.Select x, MultiVector.Select y, MultiVector.Comparison y) =>
  Exp (LLVM.Vector n (Bool, (x,y))) -> Exp (LLVM.Vector n (Bool, (x,y))) ->
  Exp (LLVM.Vector n (Bool, (x,y)))
argMinMaskedVec xy0 xy1 =
  ExprVec.select valid1 (ExprVec.select valid0 bothValid xy1) xy0
  where
    valid0 = ExprVec.fst xy0
    valid1 = ExprVec.fst xy1
    bothValid =
      ExprVec.zip
        (ExprVec.replicate Expr.true)
        (argMinVec (ExprVec.snd xy0) (ExprVec.snd xy1))
-- | Lane-wise bit test: a lane is true where 'pickBitsX86' yields a
-- non-zero value for bit position @j@.
testBlockBitVec ::
  Exp BitId -> Exp Block -> Exp (LLVM.Vector NumCounters Bool)
testBlockBitVec j block = ExprVec.cmp LLVM.CmpNE Expr.zero picked
  where
    picked = pickBitsX86 j block
-- | Expand an array of blocks by a bit dimension; every cell holds the
-- vector of lane-wise bit tests ('testBlockBitVec') for that bit position.
flattenBlockBitsVec ::
  Symb.Array BlockDim Block ->
  Symb.Array (BlockDim,BitDim) (LLVM.Vector NumCounters Bool)
flattenBlockBitsVec blocks =
  Symb.mapWithIndex
    (Expr.modify2 (atom, atom) atom
       (\(_blockIx, bitIx) block -> testBlockBitVec bitIx block))
    (Slice.apply extrudeBits blocks)
-- | Vectorised masked minimum search.  Every counter vector is paired
-- lane-wise with its (replicated) array index and with the validity flags
-- from 'flattenBlockBitsVec'; the whole array is then reduced with
-- 'argMinMaskedVec'.  The result still holds one candidate per lane.
argMinimumMaskedVec ::
  Symb.Array BlockDim Block ->
  Symb.Array (BlockDim, BitDim) Counters ->
  Exp (LLVM.Vector NumCounters (Bool, ((BlockId, BitId), Counter)))
argMinimumMaskedVec free counts =
  Symb.fold1All argMinMaskedVec candidates
  where
    indexed =
      Symb.mapWithIndex
        (\ix lanes -> ExprVec.zip (ExprVec.replicate ix) lanes)
        counts
    candidates =
      Symb.zipWith ExprVec.zip (flattenBlockBitsVec free) indexed
-- | Constant vector whose lanes count upwards from 0 in steps of 1,
-- i.e. every lane holds its own 'CounterId'.
counterIds :: Exp (LLVM.Vector NumCounters CounterId)
counterIds = ExprVec.cons $ LLVM.vector $ NonEmptyC.iterate (+ 1) 0
-- | Reduce the per-lane candidates to a single index by a sequential fold.
--
-- The lane number is first added to every candidate as 'CounterId'.
-- The vector is then dissected into its lanes, which are folded with
-- 'argMinMasked', starting from an invalid candidate with undefined
-- payload.  Not referenced elsewhere in this module;
-- 'keepMinimumMaskedCascade' is the variant in use.
_keepMinimumMaskedVector ::
  Exp (LLVM.Vector NumCounters (Bool, ((BlockId, BitId), Counter))) ->
  Exp ((BlockId, CounterId), BitId)
_keepMinimumMaskedVector candidates =
  Expr.liftM
    (\vec -> do
       lanes <- MultiValueVec.dissect vec
       best <-
         foldM
           (Expr.unliftM2 argMinMasked)
           (MultiValue.zip (MultiValue.cons False) MultiValue.undef)
           lanes
       return (MultiValue.fst (MultiValue.snd best)))
    (ExprVec.mapSnd
       (ExprVec.mapFst
          (ExprVec.mapFst (\blockIds -> ExprVec.zip blockIds counterIds)))
       candidates)
-- | Vector of @n@ flagged candidates, each consisting of a validity flag,
-- a full index (block, counter, bit) and the counter value.
type IxVector n =
  MultiValue.T
    (LLVM.Vector n (Bool, (((BlockId, CounterId), BitId), Counter)))
-- | Halve a candidate vector: two vectors of half the size are obtained
-- with 'MultiValueVec.take' and 'MultiValueVec.takeRev' (presumably the
-- leading and the trailing half) and combined lane-wise with
-- 'argMinMaskedVec'.
argMinMaskedVecHalf ::
  (TypeNum.Positive n, TypeNum.Positive n2, (n:+:n) ~ n2,
   MultiVector.Select x, MultiVector.Select y, MultiVector.Comparison y) =>
  MultiValue.T (LLVM.Vector n2 (Bool, (x, y))) ->
  LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector n (Bool, (x, y))))
argMinMaskedVecHalf x = do
  taken <- MultiValueVec.take x
  takenRev <- MultiValueVec.takeRev x
  Expr.unliftM2 argMinMaskedVec taken takenRev
-- | Reduce the per-lane candidates to a single index by repeated halving.
--
-- The lane number is first added to every candidate as 'CounterId'.
-- The vector is then halved three times with 'argMinMaskedVecHalf'
-- (16 -> 8 -> 4 -> 2 lanes) and the two remaining lanes are combined with
-- 'argMinMasked'.  The type annotation on @x2@ is what determines the
-- sizes of the intermediate vectors, so it must stay.
keepMinimumMaskedCascade ::
  Exp (LLVM.Vector NumCounters (Bool, ((BlockId, BitId), Counter))) ->
  Exp ((BlockId, CounterId), BitId)
keepMinimumMaskedCascade candidates =
  Expr.fst $ Expr.snd $
    Expr.liftM
      (\x16 -> do
         x8 <- argMinMaskedVecHalf x16
         x4 <- argMinMaskedVecHalf x8
         x2 <- argMinMaskedVecHalf x4
         lane0 <-
           MultiValueVec.extract (LLVM.valueOf 0) (x2 :: IxVector TypeNum.D2)
         lane1 <- MultiValueVec.extract (LLVM.valueOf 1) x2
         Expr.unliftM2 argMinMasked lane0 lane1)
      (ExprVec.mapSnd
         (ExprVec.mapFst
            (ExprVec.mapFst (\blockIds -> ExprVec.zip blockIds counterIds)))
         candidates)
-- | Prepare the masked minimum search used by 'minimize':
-- 'argMinimumMaskedVec' followed by 'keepMinimumMaskedCascade'; the bit
-- index of the result is converted into a single-bit mask.
keepMinimumMaskedVec ::
  IO (Phys.Array BlockDim Block ->
      Phys.Array (BlockDim, BitDim) Counters ->
      IO ((BlockId,CounterId),Subblock))
keepMinimumMaskedVec =
  Render.run
    (\free counts ->
       Expr.mapSnd singleBit
         (keepMinimumMaskedCascade (argMinimumMaskedVec free counts)))
-- | Prepare an action that lists the ids of all sets whose subblock at
-- position (block @j@, subblock @k@) is not 'disjoint' from the given bit
-- mask.
affectedRows ::
  IO (Phys.Array (SetDim,BlockDim) Block ->
      ((BlockId,CounterId),Subblock) -> IO [SetId])
affectedRows = do
  affected <-
    Render.run $ \arr ((j, k), bit) ->
      let column = Slice.apply (Slice.pickSnd j) (Symb.fix arr)
      in  findIndices
            (Symb.map
               (\block -> Expr.not (disjoint bit (extractBlock k block)))
               column)
  return $ \arr bit -> affected arr bit >>= Phys.toList
-- | Prepare the selection step: sum up the per-bit counters over all
-- available sets, pick a free position with minimal count and return the
-- ids of the sets that contain that position.
minimize ::
  IO (Phys.Array BlockDim Block ->
      Phys.Array (SetDim,BlockDim) Block -> IO [SetId])
minimize = do
  smRows <- Render.run (\arr -> sumRows (uninterleaveBits arr))
  affected <- affectedRows
  keepMin <- keepMinimumMaskedVec
  return $ \free arr -> do
    counts <- smRows arr
    pivot <- keepMin free counts
    affected arr pivot
-- | Prepare a single search step: determine the candidate sets with
-- 'minimize' and produce one successor state per candidate.
stepIO :: IO (State label -> LazyIO.T [State label])
stepIO = do
  update <- updateStateIO
  minim <- minimize
  return $ \s -> do
    candidates <-
      LazyIO.interleave
        (minim (freeElements s) (snd (availableSubsets s)))
    mapM (\k -> update k s) candidates
-- | Prepare the depth-first search.  A state without free elements yields
-- its list of used labels as the only solution; otherwise the solutions of
-- all successor states are concatenated.
searchIO :: IO (State label -> LazyIO.T [[label]])
searchIO = do
  stp <- stepIO
  nullSt <- Render.run (\free -> Expr.bool8FromP (nullSet free))
  let go s = do
        done <-
          fmap Bool8.toBool (LazyIO.interleave (nullSt (freeElements s)))
        if done
          then return [usedSubsets s]
          else fmap concat (stp s >>= mapM go)
  return go
-- | Prepare the complete solver: build the initial state from the given
-- assignments and run the search from 'searchIO' on it.
partitionsIO :: (Ord a) => IO ([ESC.Assign label (Set a)] -> LazyIO.T [[label]])
partitionsIO = do
  srch <- searchIO
  return (\assigns -> LazyIO.interleave (initStateIO assigns) >>= srch)
-- | Pure interface to 'partitionsIO'.
--
-- The definition is kept point-free on purpose: @parts@ is bound outside
-- of the argument lambda, presumably so that the preparation done by
-- 'partitionsIO' is not repeated for every argument list.
-- NOTE(review): both uses of 'unsafePerformIO' assume that preparing and
-- running the solver has no observable side effects - confirm.
partitions :: (Ord a) => [ESC.Assign label (Set a)] -> [[label]]
partitions = unsafePerformIO . LazyIO.run . parts
  where
    parts = unsafePerformIO partitionsIO