module Math.SetCover.Exact.Knead.Symbolic (
   BitSet(..),
   Block,

   SetId, SetDim, BlockId, BlockDim, DigitId, DigitDim,
   sumBags3,
   difference,
   getRow,
   nullSet,
   disjoint,
   disjointRows,
   differenceWithRow,
   findIndices,
   filterDisjointRows,
   collectRows,
   ) where

import qualified Math.SetCover.Exact.Block as Blocks

import Control.Monad.HT ((<=<))
import Control.Applicative (liftA2, (<$>))

import qualified Data.Array.Knead.Parameterized.Render as Render
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Symbolic as Symb
import qualified Data.Array.Knead.Simple.Slice as Slice
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Simple.Symbolic ((!))
import Data.Array.Knead.Expression
         (Exp, (==*), (<*), xor, (.|.*), (.&.*), )

import qualified Data.Array.Comfort.Shape as ComfortShape
import qualified Data.Array.Comfort.Storable.Unchecked as ComfortArray
import qualified Data.Array.Comfort.Boxed as Array
import Data.Array.Comfort.Boxed (Array)

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom)

import qualified Data.Word as Word
import qualified Data.Int as Int
import Data.Set (Set)

import Prelude2010
import Prelude ()



{- |
Block types that can be used as fixed-width bit sets
within Knead expressions.
-}
class (MultiValue.Logic block) => BitSet block where
   -- | Test whether a block contains no element.
   -- The instances in this module compare the block with 'Expr.zero'.
   nullBlock :: Exp block -> Exp Bool
   -- | Convert a list of sets to lists of blocks.
   -- The instances in this module forward to 'Blocks.blocksFromSets';
   -- see there for the meaning of the two result components.
   blocksFromSets :: (Ord a) => [Set a] -> ([[block]], [block])
   -- | Reduce a block to a single bit.
   -- The instances in this module compute @x AND (neg x)@,
   -- which leaves only the least significant set bit.
   keepMinimumBit :: Exp block -> Exp block

-- | 8-bit blocks: empty means equal to zero.
instance BitSet Word.Word8 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit block = keepMinimumBitPrim block

-- | 16-bit blocks: empty means equal to zero.
instance BitSet Word.Word16 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit block = keepMinimumBitPrim block

-- | 32-bit blocks: empty means equal to zero.
instance BitSet Word.Word32 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit block = keepMinimumBitPrim block

-- | 64-bit blocks: empty means equal to zero.
instance BitSet Word.Word64 where
   nullBlock = (==* Expr.zero)
   blocksFromSets = Blocks.blocksFromSets
   keepMinimumBit block = keepMinimumBitPrim block

{- |
Isolate the least significant set bit of a word
by combining the value with its own negation: @x AND (neg x)@.
-}
keepMinimumBitPrim ::
   (MultiValue.Additive a, MultiValue.Logic a) => Exp a -> Exp a
keepMinimumBitPrim =
   Expr.liftM
      (\x -> do
         negated <- MultiValue.neg x
         MultiValue.and x negated)



-- | Bit block type of the concrete arrays processed in this module.
type Block = Word.Word64

-- | Index type of 'SetDim'.
-- SetId must allow negative numbers since it is used for empty plain Arrays.
type SetId = Int.Int32
-- | Index type of 'BlockDim'.
type BlockId = Int.Int32
-- | Index type of 'DigitDim'.
type DigitId = Word.Word32

-- | Zero-based extent of the set dimension.
type SetDim = Shape.ZeroBased SetId
-- | Zero-based extent of the block dimension.
type BlockDim = Shape.ZeroBased BlockId
-- | Zero-based extent of the digit dimension.
type DigitDim = Shape.ZeroBased DigitId


{- |
Full adder that works on all bit positions of a block at once.
'addLow' combines the three inputs with exclusive or (sum bits),
'addHigh' yields the corresponding carry bits.
-}
addLow, addHigh :: MultiValue.Logic a => Exp a -> Exp a -> Exp a -> Exp a
addLow x y z = x `xor` y `xor` z
addHigh x y z = z .&.* (x .|.* y) .|.* x .&.* y

{- |
Compile one reduction step:
The sets are paired by 'halfBags'
and every pair is added with 'addLow' and 'addHigh',
where the carry starts at zero and is threaded through 'Render.MapAccumLSimple'.
-}
add2 ::
   IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block ->
       IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block))
add2 =
   Render.run $ \bags ->
      let zeroCarry = Symb.fill (Expr.fst (Symb.shape bags)) Expr.zero
          fullAdder =
             Expr.modify2 atom (atom,atom) $ \carry (x,y) ->
                (addHigh x y carry, addLow x y carry)
      in  Render.MapAccumLSimple fullAdder zeroCarry (halfBags bags)


{- |
Decomposition pattern for a 'Shape.ZeroBased' shape
whose size is matched as a single 'atom'.
-}
zbAtom :: Shape.ZeroBased (MultiValue.Atom a)
zbAtom = Shape.ZeroBased atom

{- |
Arrange the sets in pairs.
The result has half as many sets (rounded up) and one digit more.
Element @((n,j),k)@ consists of the elements of sets @2*n@ and @2*n+1@;
the second component is zero if set @2*n+1@ does not exist
and the whole pair is zero in the additional digit.
-}
halfBags ::
   Symb.Array ((SetDim, BlockDim), DigitDim) Block ->
   Symb.Array ((SetDim, BlockDim), DigitDim) (Block,Block)
halfBags xs =
   Symb.map pairAt (Symb.id pairedShape)
   where
      shapeOfXs = Symb.shape xs

      -- half the sets (rounded up), one additional digit
      pairedShape =
         Expr.modify ((zbAtom, atom), zbAtom)
            (\((Shape.ZeroBased setCount, blockShape),
               Shape.ZeroBased digitCount) ->
               ((Shape.ZeroBased (Expr.idiv (setCount+1) 2), blockShape),
                Shape.ZeroBased (digitCount+1)))
            shapeOfXs

      -- fetch the pair of source elements for a target index
      pairAt =
         Expr.modify2 ((zbAtom, atom), zbAtom) ((atom,atom),atom)
            (\((Shape.ZeroBased setCount, _blockShape),
               Shape.ZeroBased digitCount)
              ((n,j),k) ->
               elseIfThen Expr.zero (k <* digitCount) $
               Expr.zip
                  (xs ! Expr.zip (Expr.zip (2*n) j) k)
                  (elseIfThen Expr.zero (2*n+1 <* setCount)
                     (xs ! Expr.zip (Expr.zip (2*n+1) j) k)))
            shapeOfXs

{- |
'Expr.ifThenElse' with the alternative given first:
@elseIfThen y c x@ is @x@ if @c@ holds and @y@ otherwise.
-}
elseIfThen :: MultiValue.C a => Exp a -> Exp Bool -> Exp a -> Exp a
elseIfThen fallback cond value = Expr.ifThenElse cond value fallback


{- |
Compile a function that picks the set with index zero
and swaps the remaining block and digit dimensions.
-}
removeDimension ::
   IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block ->
       IO (Phys.Array (DigitDim, BlockDim) Block))
removeDimension =
   Render.run (Symb.fix . Slice.apply firstSetTransposed)
   where
      firstSetTransposed =
         Slice.compose
            (Slice.first (Slice.pickFst Expr.zero))
            Slice.transpose

{- |
Compile 'add2' and 'removeDimension' once
and return a function that applies the 'add2' step
as long as more than one set is left,
and then drops the set dimension.
-}
sumLoop ::
   IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block ->
       IO (Phys.Array (DigitDim, BlockDim) Block))
sumLoop = do
   halve <- add2
   finish <- removeDimension

   let setCount arr =
          ComfortShape.zeroBasedSize (fst (fst (Phys.shape arr)))
       loop bags
          | setCount bags > 1 = loop =<< halve bags
          | otherwise = finish bags

   return loop

{- |
Append a digit dimension of size one to the shape of an array.
-}
addSingleDim :: Phys.Array sh a -> Phys.Array (sh,DigitDim) a
addSingleDim = ComfortArray.mapShape (\sh -> (sh, Shape.ZeroBased 1))

{-
ToDo:
We could use a carry-save adder that would enable more parallelism.
Unfortunately, currently we cannot benefit from this opportunity.
-}
{- |
Variant of the summation that feeds the sets directly into 'sumLoop'
after equipping them with a digit dimension of size one.
-}
_sumBags ::
   IO (Phys.Array (SetDim,BlockDim) Block ->
       IO (Phys.Array (DigitDim,BlockDim) Block))
_sumBags = fmap (\loop -> loop . addSingleDim) sumLoop


{- |
A faster first addition step.
The first addition has no incoming carry to propagate,
so three sets can be added at once.
This reduces the number of rows to a third.
Digit zero receives the sum bits ('addLow'),
the other digit the carry bits ('addHigh').
-}
_add3 ::
   IO (Phys.Array (SetDim, BlockDim) Block ->
       IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block))
_add3 =
   Render.run $ \xs ->
      let triples = Slice.apply (Slice.extrudeSnd digitDim2) (thirdBags xs)
          selectDigit =
             Expr.modify2 (atom,atom) (atom,atom,atom) $
             \(_,digit) (x,y,z) ->
                Expr.ifThenElse (digit ==* Expr.zero)
                   (addLow x y z) (addHigh x y z)
      in  Symb.mapWithIndex selectDigit triples

{- |
First addition step: add groups of three sets from 'thirdBags'.
A digit dimension of size two is added,
where digit zero holds the sum bits ('addLow')
and the other digit holds the carry bits ('addHigh').
-}
add3 ::
   IO (Phys.Array (SetDim, BlockDim) Block ->
       IO (Phys.Array ((SetDim, BlockDim), DigitDim) Block))
add3 =
   Render.run (Render.AddDimension digitDim2 selectDigit . thirdBags)
   where
      selectDigit =
         Expr.modify2 atom (atom,atom,atom) $ \digit (x,y,z) ->
            Expr.ifThenElse (digit ==* Expr.zero)
               (addLow x y z) (addHigh x y z)

{- |
Digit dimension of size two.
-}
digitDim2 :: Exp DigitDim
digitDim2 = Expr.compose (Shape.ZeroBased (Expr.fromInteger' 2))

{- |
Arrange the sets in groups of three.
The result has a third of the sets (rounded up).
Element @(n,j)@ consists of the elements of the sets
@3*n@, @3*n+1@ and @3*n+2@,
where sets beyond the end of the input are replaced by zero.
-}
thirdBags ::
   Symb.Array (SetDim, BlockDim) Block ->
   Symb.Array (SetDim, BlockDim) (Block,Block,Block)
thirdBags xs =
   Symb.map tripleAt (Symb.id groupedShape)
   where
      -- a third of the sets (rounded up), blocks unchanged
      groupedShape =
         Expr.mapFst
            (Expr.modify zbAtom $ \(Shape.ZeroBased setCount) ->
               Shape.ZeroBased (Expr.idiv (setCount+2) 3))
            (Symb.shape xs)

      tripleAt =
         Expr.modify (atom,atom) $ \(n,j) ->
            Expr.zip3
               (xs ! Expr.zip (3*n) j)
               (condAccess xs (3*n+1) j)
               (condAccess xs (3*n+2) j)

{- |
Read element @(n,j)@ of the array
if @n@ is below the number of sets, and yield zero otherwise.
-}
condAccess ::
   Symb.Array (SetDim, BlockDim) Block -> Exp SetId -> Exp BlockId -> Exp Block
condAccess bag setIx blockIx =
   Expr.ifThenElse inRange (bag ! Expr.zip setIx blockIx) Expr.zero
   where
      setCount = Shape.zeroBasedSize (Expr.fst (Symb.shape bag))
      inRange = setIx <* setCount

{- |
Sum a bag of sets:
First apply the three-way addition 'add3',
then reduce the result with 'sumLoop'.
-}
sumBags3 ::
   IO (Phys.Array (SetDim,BlockDim) Block ->
       IO (Phys.Array (DigitDim,BlockDim) Block))
sumBags3 = do
   reduce <- sumLoop
   firstStep <- add3
   return $ \bags -> reduce =<< firstStep bags


{- |
Bitwise set difference: keep the bits of the first operand
that are not set in the second one.
-}
difference :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
difference minuend subtrahend = minuend .&.* Expr.complement subtrahend

{- |
Remove the elements of row @k@ of the bag from the set @x@,
block by block.
-}
differenceWithRow ::
   (Shape.C k, MultiValue.Logic block) =>
   Symb.Array BlockDim block -> Exp (Shape.Index k) ->
   Symb.Array (k,BlockDim) block -> Symb.Array BlockDim block
differenceWithRow x k bag = Symb.zipWith difference x row
   where
      row = getRow k bag


{- |
Two blocks are disjoint if their bitwise conjunction is a null block.
-}
disjoint :: (BitSet block) => Exp block -> Exp block -> Exp Bool
disjoint x y = nullBlock (x .&.* y)

{- |
Select the row with first index @k@ from a two-dimensional array.
-}
getRow ::
   (Shape.C k, MultiValue.C block) =>
   Exp (Shape.Index k) ->
   Symb.Array (k, BlockDim) block -> Symb.Array BlockDim block
getRow k bag = Slice.apply (Slice.pickFst k) bag

{- |
A set is null if the search for a block that is not null finds nothing.
-}
nullSet :: (BitSet block) => Symb.Array BlockDim block -> Exp Bool
nullSet blocks =
   Expr.maybe Expr.true (\_ -> Expr.false) occupied
   where
      occupied = Symb.findAll (\block -> Expr.not (nullBlock block)) blocks

{- |
Check whether the rows @k0@ and @k1@ of the bag have no common element,
i.e. whether their block-wise conjunction is a null set.
-}
disjointRow ::
   (BitSet block) =>
   Exp SetId -> Exp SetId -> Symb.Array (SetDim, BlockDim) block -> Exp Bool
disjointRow k0 k1 bag = nullSet common
   where
      common = Symb.zipWith (.&.*) (getRow k0 bag) (getRow k1 bag)

{- |
For every row of @sets@ tell whether it is disjoint from row @k0@.
-}
disjointRows ::
   (BitSet block) =>
   Exp SetId -> Symb.Array (SetDim,BlockDim) block -> Symb.Array SetDim Bool
disjointRows k0 sets = Symb.map disjointFromK0 setIds
   where
      -- array that contains at every position its own index
      setIds = Symb.id (Expr.fst (Symb.shape sets))
      disjointFromK0 k1 = disjointRow k0 k1 sets


{- |
Pair every element of the array with its index
and set up a 'Render.MapFilter'
that selects by the 'Bool' component and yields the index component.
-}
findIndices ::
   Symb.Array SetDim Bool -> Render.MapFilter SetDim (SetId,Bool) SetId
findIndices arr = Render.MapFilter Expr.fst Expr.snd indexed
   where
      indexed = Symb.zip (Symb.id (Symb.shape arr)) arr

{- |
Gather rows of @sets@ via 'Symb.backpermute':
The result has as many rows as @rows@ has elements
and the row index of the result is translated by lookup in @rows@.
-}
collectRows ::
   (MultiValue.C block) =>
   Symb.Array SetDim SetId ->
   Symb.Array (SetDim,BlockDim) block -> Symb.Array (SetDim,BlockDim) block
collectRows rows sets = Symb.backpermute targetShape sourceIndex sets
   where
      -- set dimension taken from @rows@, block dimension from @sets@
      targetShape = Expr.mapFst (const (Symb.shape rows)) (Symb.shape sets)
      sourceIndex = Expr.mapFst (rows !)

{- |
Compile a function that, given a set index @k@,
keeps only the sets that are disjoint from set @k@ (see 'disjointRows')
and restricts the labels to the same selection.
-}
filterDisjointRows ::
   IO (SetId ->
       (Array SetDim label, Phys.Array (SetDim,BlockDim) Block) ->
       IO (Array SetDim label, Phys.Array (SetDim,BlockDim) Block))
filterDisjointRows = do
   disjRows <- Render.run $ \k sets -> findIndices (disjointRows k sets)
   collect <- Render.run collectRows
   return $ \k (labels,sets) -> do
      perm <- disjRows k sets
      selected <- Phys.toList perm
      remaining <- collect perm sets
      return
         (Array.fromList (Phys.shape perm) (map (labels Array.!) selected),
          remaining)