module Math.SetCover.Exact.Knead.Symbolic (
BitSet(..),
Block,
SetId, SetDim, BlockId, BlockDim, DigitId, DigitDim,
sumBags3,
difference,
getRow,
nullSet,
disjoint,
disjointRows,
differenceWithRow,
findIndices,
filterDisjointRows,
collectRows,
) where
import qualified Math.SetCover.Exact.Block as Blocks
import Control.Monad.HT ((<=<))
import Control.Applicative (liftA2, (<$>))
import qualified Data.Array.Knead.Parameterized.Render as Render
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Symbolic as Symb
import qualified Data.Array.Knead.Simple.Slice as Slice
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Simple.Symbolic ((!))
import Data.Array.Knead.Expression
(Exp, (==*), (<*), xor, (.|.*), (.&.*), )
import qualified Data.Array.Comfort.Shape as ComfortShape
import qualified Data.Array.Comfort.Storable.Unchecked as ComfortArray
import qualified Data.Array.Comfort.Boxed as Array
import Data.Array.Comfort.Boxed (Array)
import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom)
import qualified Data.Word as Word
import qualified Data.Int as Int
import Data.Set (Set)
import Prelude2010
import Prelude ()
-- | Machine words that can serve as one block of a bit set inside knead
-- expressions.  The instances in this module are the unsigned words
-- 'Word.Word8', 'Word.Word16', 'Word.Word32' and 'Word.Word64'.
class (MultiValue.Logic block) => BitSet block where
   -- | Test whether a block has no bit set.
   nullBlock :: Exp block -> Exp Bool

   -- | Convert a list of sets into block form.  The instances in this
   -- module delegate to 'Blocks.blocksFromSets'; consult it for the meaning
   -- of the two result components.
   blocksFromSets :: (Ord a) => [Set a] -> ([[block]], [block])

   -- | Clear every bit of a block except the lowest set one.
   keepMinimumBit :: Exp block -> Exp block
-- All word instances share one implementation:
--   * a block is null iff it compares equal to zero,
--   * set conversion is delegated to 'Blocks.blocksFromSets',
--   * the lowest set bit is isolated by 'keepMinimumBitPrim'.

instance BitSet Word.Word8 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit = keepMinimumBitPrim

instance BitSet Word.Word16 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit = keepMinimumBitPrim

instance BitSet Word.Word32 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit = keepMinimumBitPrim

instance BitSet Word.Word64 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit = keepMinimumBitPrim
-- | Keep only the lowest set bit of a word.  It is computed as
-- @x .&. negate x@, the two's-complement identity that isolates the least
-- significant one bit.  A zero input yields zero.
keepMinimumBitPrim ::
   (MultiValue.Additive a, MultiValue.Logic a) => Exp a -> Exp a
keepMinimumBitPrim =
   Expr.liftM $ \x -> MultiValue.neg x >>= MultiValue.and x
-- | Word type used for one block of a bit set in the exported routines.
type Block = Word.Word64

-- | Index of a set (row) and the zero-based shape over set indices.
type SetId = Int.Int32
type SetDim = Shape.ZeroBased SetId

-- | Index of a block within a bit set and the shape over block indices.
type BlockId = Int.Int32
type BlockDim = Shape.ZeroBased BlockId

-- | Index of a binary digit of a counter and the shape over digits.
type DigitId = Word.Word32
type DigitDim = Shape.ZeroBased DigitId
-- | Bitwise full adder on whole words.  For every bit position,
-- 'addLow' yields the sum bit (@a xor b xor c@) and 'addHigh' yields the
-- carry bit, i.e. the majority of @a@, @b@ and @c@.
addLow, addHigh :: MultiValue.Logic a => Exp a -> Exp a -> Exp a -> Exp a
addLow a b c = (a `xor` b) `xor` c
addHigh a b c = (c .&.* (a .|.* b)) .|.* (a .&.* b)
-- | One reduction step of the bag sum.  It pairs up set-rows @2n@ and
-- @2n+1@ (see 'halfBags') and adds each pair digit by digit, propagating a
-- carry along the digit dimension.  'addLow' produces the digit and
-- 'addHigh' the next carry.
--
-- The carry accumulator starts at zero for every (set, block) position.
-- It is sized from the shape of the scanned (halved) array, so that the
-- accumulator and the scanned array agree on their (set, block) shape.
-- The unhalved input has up to twice as many set-rows.
add2 ::
   IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block ->
       IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block))
add2 =
   Render.run $ \xs ->
      let pairs = halfBags xs
      in  Render.MapAccumLSimple
             (Expr.modify2 atom (atom,atom) $ \carry (a,b) ->
                (addHigh a b carry, addLow a b carry))
             (Symb.fill (Expr.fst (Symb.shape pairs)) Expr.zero)
             pairs
-- | Pattern for 'Expr.modify' that exposes the size expression stored in
-- a 'Shape.ZeroBased' shape.
zbAtom :: Shape.ZeroBased (MultiValue.Atom a)
zbAtom = Shape.ZeroBased MultiValue.atom
-- | Pair up set-rows @2n@ and @2n+1@ of a digit array as input for 'add2'.
--
-- The result has @(numSets+1) `idiv` 2@ set-rows, the same block
-- dimension, and one more digit than the input.  Element @((n,j),k)@ is
-- the pair of input elements @((2n,j),k)@ and @((2n+1,j),k)@.  The second
-- component is zero when row @2n+1@ does not exist (odd row count).  The
-- whole pair is zero in the extra top digit @k == numDigits@, which leaves
-- room for the final carry.
halfBags ::
   Symb.Array ((SetDim, BlockDim), DigitDim) Block ->
   Symb.Array ((SetDim, BlockDim), DigitDim) (Block,Block)
halfBags xs =
   Symb.map pairRows (Symb.id halvedShape)
   where
      shapeXs = Symb.shape xs
      -- Halve the set dimension (rounding up) and add one digit.
      halvedShape =
         Expr.modify ((zbAtom, atom), zbAtom)
            (\((Shape.ZeroBased numSets, shBlocks), Shape.ZeroBased numDigits) ->
               ((Shape.ZeroBased (Expr.idiv (numSets+1) 2), shBlocks),
                Shape.ZeroBased (numDigits+1)))
            shapeXs
      pairRows =
         Expr.modify2 ((zbAtom, atom), zbAtom) ((atom,atom),atom)
            (\((Shape.ZeroBased numSets, _shBlocks), Shape.ZeroBased numDigits)
              ((n,j),k) ->
               let evenRow = xs ! Expr.zip (Expr.zip (2*n) j) k
                   -- A missing partner row contributes zero.
                   oddRow =
                      elseIfThen Expr.zero (2*n+1<*numSets)
                         (xs ! Expr.zip (Expr.zip (2*n+1) j) k)
               in  -- The extra top digit carries no input data.
                   elseIfThen Expr.zero (k<*numDigits) $
                   Expr.zip evenRow oddRow)
            shapeXs
-- | 'Expr.ifThenElse' with the fallback value first.  A default can then
-- be written in front of the condition and the guarded expression.
elseIfThen :: MultiValue.C a => Exp a -> Exp Bool -> Exp a -> Exp a
elseIfThen fallback cond value = Expr.ifThenElse cond value fallback
-- | Drop the set dimension by selecting set-row 0, then swap the
-- remaining block and digit dimensions.  The result is indexed by
-- (digit, block).
removeDimension ::
   IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block ->
       IO (Phys.Array (DigitDim, BlockDim) Block))
removeDimension =
   Render.run $ \xs ->
      Symb.fix $ Slice.apply firstRowTransposed xs
   where
      firstRowTransposed =
         Slice.first (Slice.pickFst Expr.zero) `Slice.compose` Slice.transpose
-- | Render 'add2' and 'removeDimension' once.  Returns a function that
-- repeatedly applies 'add2', halving the number of set-rows each time,
-- while more than one set-row remains.  It then drops the set dimension
-- with 'removeDimension'.
--
-- NOTE(review): an input with zero set-rows goes straight to
-- 'removeDimension', which still selects set-row 0.  Confirm that callers
-- never pass an empty set dimension.
sumLoop ::
   IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block ->
       IO (Phys.Array (DigitDim, BlockDim) Block))
sumLoop = do
   addPairs <- add2
   dropSetDim <- removeDimension
   let reduce arr
          | rowCount arr > 1 = reduce =<< addPairs arr
          | otherwise        = dropSetDim arr
       rowCount = ComfortShape.zeroBasedSize . fst . fst . Phys.shape
   return reduce
-- | Attach a trailing digit dimension of size one.  Only the shape is
-- changed; the element data is reused as is.
addSingleDim :: Phys.Array sh a -> Phys.Array (sh,DigitDim) a
addSingleDim = ComfortArray.mapShape (\sh -> (sh, Shape.ZeroBased 1))
-- | Sum all set-rows by 'sumLoop'.  Each input row is first treated as a
-- counter with a single digit (see 'addSingleDim').
_sumBags ::
   IO (Phys.Array (SetDim,BlockDim) Block ->
       IO (Phys.Array (DigitDim,BlockDim) Block))
_sumBags = fmap (. addSingleDim) sumLoop
-- | Alternative formulation of 'add3' using an extruded digit dimension
-- and 'Symb.mapWithIndex'.  Triples of set-rows (see 'thirdBags') are
-- added into a two-digit result.  Digit 0 holds the sum bit ('addLow');
-- digit 1 holds the carry bit ('addHigh').
_add3 ::
   IO (Phys.Array (SetDim, BlockDim) Block ->
       IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block))
_add3 =
   Render.run $ \xs ->
      Symb.mapWithIndex digitOfSum $
      Slice.apply (Slice.extrudeSnd digitDim2) (thirdBags xs)
   where
      -- Select the sum or carry bit depending on the digit index.
      digitOfSum =
         Expr.modify2 (atom,atom) (atom,atom,atom) $ \(_,k) (a,b,c) ->
            Expr.ifThenElse (k ==* Expr.zero) (addLow a b c) (addHigh a b c)
-- | Add triples of set-rows (see 'thirdBags') into two-digit counters.
-- Digit 0 holds the sum bit ('addLow'); digit 1 holds the carry bit
-- ('addHigh').
add3 ::
   IO (Phys.Array (SetDim, BlockDim) Block ->
       IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block))
add3 =
   Render.run $ \xs ->
      Render.AddDimension digitDim2 digitOfSum (thirdBags xs)
   where
      -- Select the sum or carry bit depending on the digit index.
      digitOfSum =
         Expr.modify2 atom (atom,atom,atom) $ \k (a,b,c) ->
            Expr.ifThenElse (k ==* Expr.zero) (addLow a b c) (addHigh a b c)
-- | Digit shape of size two: a sum bit and a carry bit.
digitDim2 :: Exp DigitDim
digitDim2 = Expr.compose (Shape.ZeroBased (Expr.fromInteger' 2))
-- | Group set-rows @3n@, @3n+1@ and @3n+2@ into triples.
--
-- The result has @(numSets+2) `idiv` 3@ set-rows and the same block
-- dimension.  Row @3n@ always exists for such @n@, so it is read directly.
-- The other two rows go through 'condAccess' and read as zero when they
-- lie past the last set-row.
thirdBags ::
   Symb.Array (SetDim, BlockDim) Block ->
   Symb.Array (SetDim, BlockDim) (Block,Block,Block)
thirdBags xs =
   Symb.map gatherTriple (Symb.id groupedShape)
   where
      -- Divide the set dimension by three, rounding up.
      groupedShape =
         Expr.mapFst
            (Expr.modify zbAtom $ \(Shape.ZeroBased numSets) ->
               Shape.ZeroBased $ Expr.idiv (numSets+2) 3)
            (Symb.shape xs)
      gatherTriple =
         Expr.modify (atom,atom) $ \(n,j) ->
            let base = 3*n
            in  Expr.zip3
                   (xs ! Expr.zip base j)
                   (condAccess xs (base+1) j)
                   (condAccess xs (base+2) j)
-- | Read element @(n, j)@, or zero if set-row @n@ is not below the number
-- of set-rows.
condAccess ::
   Symb.Array (SetDim, BlockDim) Block -> Exp SetId -> Exp BlockId -> Exp Block
condAccess xs n j =
   Expr.ifThenElse inRange (xs ! Expr.zip n j) Expr.zero
   where
      inRange = n <* Shape.zeroBasedSize (Expr.fst (Symb.shape xs))
-- | Sum all set-rows.  'add3' first combines rows in triples into
-- two-digit counters, then 'sumLoop' adds the counters pairwise until one
-- row remains.  The result is indexed by (digit, block).
sumBags3 ::
   IO (Phys.Array (SetDim,BlockDim) Block ->
       IO (Phys.Array (DigitDim,BlockDim) Block))
sumBags3 = do
   reduce <- sumLoop
   groupByThree <- add3
   return (reduce <=< groupByThree)
-- | Bitwise set difference: the bits of @x@ that are not set in @y@.
difference :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
difference x = (x .&.*) . Expr.complement
-- | Remove from the bit set @x@ all bits contained in row @k@ of @bag@,
-- block by block.
differenceWithRow ::
   (Shape.C k, MultiValue.Logic block) =>
   Symb.Array BlockDim block -> Exp (Shape.Index k) ->
   Symb.Array (k,BlockDim) block -> Symb.Array BlockDim block
differenceWithRow x k = Symb.zipWith difference x . getRow k
-- | Test whether two blocks share no set bit.
disjoint :: (BitSet block) => Exp block -> Exp block -> Exp Bool
disjoint x = nullBlock . (x .&.*)
-- | Slice out row @k@ of a two-dimensional block array.
getRow ::
   (Shape.C k, MultiValue.C block) =>
   Exp (Shape.Index k) ->
   Symb.Array (k, BlockDim) block -> Symb.Array BlockDim block
getRow = Slice.apply . Slice.pickFst
-- | Test whether a bit set is empty.  It searches for a non-null block
-- with 'Symb.findAll' and reports 'True' when none is found.
nullSet :: (BitSet block) => Symb.Array BlockDim block -> Exp Bool
nullSet set =
   Expr.maybe Expr.true (const Expr.false) $ Symb.findAll hasBits set
   where
      hasBits b = Expr.not (nullBlock b)
-- | Test whether rows @k0@ and @k1@ of @bag@ have no common bit.
disjointRow ::
   (BitSet block) =>
   Exp SetId -> Exp SetId -> Symb.Array (SetDim, BlockDim) block -> Exp Bool
disjointRow k0 k1 bag =
   nullSet (Symb.zipWith (.&.*) (row k0) (row k1))
   where
      row k = getRow k bag
-- | For every row @k1@ of @sets@, tell whether it is disjoint from row
-- @k0@.  Row @k0@ compared with itself yields 'True' only if it is empty.
disjointRows ::
   (BitSet block) =>
   Exp SetId -> Symb.Array (SetDim,BlockDim) block -> Symb.Array SetDim Bool
disjointRows k0 sets =
   Symb.map (flip (disjointRow k0) sets) $
   Symb.id $ Expr.fst $ Symb.shape sets
-- | Indices of all 'True' entries.  Each element is paired with its index;
-- the pairs are filtered on the flag ('Expr.snd') and mapped to the index
-- ('Expr.fst').
findIndices ::
   Symb.Array SetDim Bool -> Render.MapFilter SetDim (SetId,Bool) SetId
findIndices arr =
   Render.MapFilter Expr.fst Expr.snd indexed
   where
      indexed = Symb.zip (Symb.id (Symb.shape arr)) arr
-- | Gather rows: row @i@ of the result is row @rows ! i@ of @sets@.  The
-- result has as many rows as @rows@ has entries and the block dimension
-- of @sets@.
collectRows ::
   (MultiValue.C block) =>
   Symb.Array SetDim SetId ->
   Symb.Array (SetDim,BlockDim) block -> Symb.Array (SetDim,BlockDim) block
collectRows rows sets =
   Symb.backpermute gatheredShape sourceIndex sets
   where
      gatheredShape = Expr.mapFst (const (Symb.shape rows)) (Symb.shape sets)
      sourceIndex = Expr.mapFst (rows !)
-- | Render the row filter once.  Returns a function that, given a row
-- index @k@ and a pair of labels and bit-set rows, keeps exactly those
-- rows (and their labels) whose bit set is disjoint from row @k@.  Row @k@
-- itself is kept only if it is empty, because it is compared with itself.
--
-- NOTE(review): @labels@ is assumed to hold one entry per row of @sets@,
-- in the same order.
filterDisjointRows ::
   IO (SetId ->
       (Array SetDim label, Phys.Array (SetDim,BlockDim) Block) ->
       IO (Array SetDim label, Phys.Array (SetDim,BlockDim) Block))
filterDisjointRows = do
   disjRows <- Render.run $ \k sets -> findIndices $ disjointRows k sets
   collect <- Render.run collectRows
   return $ \k (labels, sets) -> do
      perm <- disjRows k sets
      keptLabels <-
         Array.fromList (Phys.shape perm) . map (labels Array.!)
            <$> Phys.toList perm
      keptSets <- collect perm sets
      return (keptLabels, keptSets)