{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Math.SetCover.Exact.Knead.Saturated (
   partitionsIO, searchIO, stepIO,
   partitions,
   State(..), initStateIO, updateStateIO,
   ) where

import qualified Math.SetCover.Exact as ESC
import Math.SetCover.Exact.Knead.Vector (Block)
import Math.SetCover.Exact.Knead.Symbolic
         (SetId, SetDim, BlockId, BlockDim, blocksFromSets,
          nullSet, disjoint, disjointRows, differenceWithRow,
          findIndices, collectRows)

import qualified Control.Monad.HT as Monad
import Control.Monad.HT ((<=<))
import Control.Monad (foldM)
import Control.Applicative (liftA2, liftA3, pure, (<$>), (<*>))

import qualified Data.Array.Knead.Parameterized.Render as Render
import qualified Data.Array.Knead.Simple.Physical as Phys
import qualified Data.Array.Knead.Simple.Symbolic as Symb
import qualified Data.Array.Knead.Simple.Slice as Slice
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression.Vector as ExprVec
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, (/=*), (<*), (.&.*), )

import qualified Data.Array.Comfort.Boxed as Array
import Data.Array.Comfort.Boxed (Array)

import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Vector as MultiValueVec
import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Vector.Memory ()
import LLVM.Extra.Multi.Value (atom)

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Data.Num.Decimal ((:+:))

import qualified System.IO.Lazy as LazyIO
import System.IO.Unsafe (unsafePerformIO)

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.List.Match as Match
import qualified Data.Set as Set
import qualified Data.Word as Word
import qualified Data.Int as Int
import qualified Data.Bool8 as Bool8
import qualified Data.Bits as Bits
import Data.Functor.Compose (Compose(Compose,getCompose))
import Data.Set (Set)

import Prelude2010
import Prelude ()


type NumCounters = TypeNum.D16
type Counter = Word.Word8
type Counters = LLVM.Vector NumCounters Counter
type Subblock = Word.Word8
type Block16 = LLVM.Vector TypeNum.D8 Word.Word16
-- type Block128 = LLVM.WordN TypeNum.D128

bitSize :: Int
bitSize = Bits.bitSize (0::Counter)

numCounters :: Integer
numCounters =
   TypeNum.integerFromSingleton
      (TypeNum.singleton :: TypeNum.Singleton NumCounters)

type CounterId = Int.Int16
type BitId = Int.Int8

type CounterDim = Shape.ZeroBased CounterId
type BitDim = Shape.ZeroBased BitId


data State label =
   State {
      availableSubsets ::
         (Array SetDim label, Phys.Array (SetDim,BlockDim) Block),
      freeElements :: Phys.Array BlockDim Block,
      usedSubsets :: [label]
   }

initStateIO :: (Ord a) => [ESC.Assign label (Set a)] -> IO (State label)
initStateIO assigns = do
   let neAssigns = filter (not . Set.null . ESC.labeledSet) assigns
       (avails, freeBlocks) = blocksFromSets $ map ESC.labeledSet neAssigns
       shSets = Shape.ZeroBased $ fromIntegral $ length neAssigns
   free <- Phys.vectorFromList freeBlocks
   avail <-
      Phys.fromList (shSets, Phys.shape free) $
      concatMap (Match.take freeBlocks) avails
   return $
      State {
         availableSubsets =
            (Array.fromList shSets $ map ESC.label neAssigns, avail),
         freeElements = free,
         usedSubsets = []
      }


repVec :: Counter -> Exp Counters
repVec = Expr.fromInteger' . toInteger

{- |
We add bytes with saturation.
The first operand must consist exclusively of zeros and ones.

With saturation we perform the same as the unoptimized algorithm
if the element with minimum occurrence is contained in at most 254 sets.
This is pretty much and should never happen.
If all elements occur in more than 254 sets
then we will choose the first one
which might lead to an unnecessary long case analysis,
but would still yield correct results.
-}
incSat :: Exp Counters -> Exp Counters -> Exp Counters
incSat x y =
   let maxBnd = repVec maxBound
   in  ExprVec.select (ExprVec.cmp LLVM.CmpEQ y maxBnd)
         maxBnd (Expr.add x y)

incSatGeneric ::
   LLVM.Value Counters -> LLVM.Value Counters ->
   LLVM.CodeGenFunction r (LLVM.Value Counters)
incSatGeneric x y =
   (\(MultiValue.Cons z) -> getCompose z)
   <$>
   Expr.unliftM2 incSat
      (MultiValue.Cons (Compose x)) (MultiValue.Cons (Compose y))

incSatX86 :: Exp Counters -> Exp Counters -> Exp Counters
incSatX86 =
   Expr.liftM2
      (MultiValue.liftM2
         (\xc yc ->
            Compose <$>
            case (getCompose xc, getCompose yc) of
               (x,y) ->
                  Ext.run (incSatGeneric x y)
                     (X86.paddusb128 <*> pure x <*> pure y)))

sumRows ::
   Symb.Array (SetDim, blockDim) Counters ->
   Render.FoldOuterL SetDim blockDim Counters Counters
sumRows xs =
   Render.FoldOuterL (flip incSatX86)
      (Symb.fill (Expr.snd $ Symb.shape xs) Expr.zero) xs


extrudeBits :: Slice.T sh (sh, BitDim)
extrudeBits =
   Slice.extrudeSnd $ Expr.compose $ Shape.ZeroBased $
   Expr.fromInteger' $ toInteger bitSize

extrudeCounters :: Slice.T sh (sh, CounterDim)
extrudeCounters =
   Slice.extrudeSnd $ Expr.compose $ Shape.ZeroBased $
   Expr.fromInteger' numCounters

toCounters :: Exp Block -> Exp Counters
toCounters =
   Expr.liftM (MultiValue.liftM (fmap Compose . LLVM.bitcast))

_pickBits :: Exp BitId -> Exp Block -> Exp Counters
_pickBits k block =
   repVec 1 .&.* Expr.shr (toCounters block) (ExprVec.replicate (bitPos k))


word16 :: Exp BitId -> Exp Word.Word16
word16 = Expr.liftM (MultiValue.liftM LLVM.ext) . bitPos

toBlock16 :: Exp Block -> Exp Block16
toBlock16 =
   Expr.liftM (MultiValue.liftM (fmap Compose . LLVM.bitcast))

fromBlock16 :: Exp Block16 -> Exp Counters
fromBlock16 =
   Expr.liftM (MultiValue.liftM (fmap Compose . LLVM.bitcast . getCompose))

pickBitsX86 :: Exp BitId -> Exp Block -> Exp Counters
pickBitsX86 k block =
   repVec 1 .&.*
   fromBlock16 (Expr.shr (toBlock16 block) (ExprVec.replicate (word16 k)))

uninterleaveBits ::
   Symb.Array (SetDim, BlockDim) Block ->
   Symb.Array (SetDim, (BlockDim, BitDim)) Counters
uninterleaveBits =
   Symb.mapWithIndex (\ix block -> pickBitsX86 (Expr.snd (Expr.snd ix)) block) .
   Slice.apply (Slice.second extrudeBits)


filterDisjointRows ::
   IO (SetId ->
       (Array SetDim label, Phys.Array (SetDim,BlockDim) Block) ->
       IO (Array SetDim label, Phys.Array (SetDim,BlockDim) Block))
filterDisjointRows = do
   disjRows <- Render.run $ \k -> findIndices . disjointRows k
   collect <- Render.run collectRows
   return $ \k0 (labels,sets) -> do
      perm <- disjRows k0 sets
      liftA2 (,)
         (Array.fromList (Phys.shape perm) . map (labels Array.!)
            <$> Phys.toList perm)
         (collect perm sets)

updateStateIO :: IO (SetId -> State label -> LazyIO.T (State label))
updateStateIO = do
   filt <- filterDisjointRows
   diff <- Render.run differenceWithRow
   return $ \k s ->
      LazyIO.interleave $
      liftA3 State
         (filt k $ availableSubsets s)
         (diff (freeElements s) k $ snd $ availableSubsets s)
         (pure (fst (availableSubsets s) Array.! k : usedSubsets s))


mvvec :: MultiValue.T (LLVM.Vector n a) -> MultiVector.T n a
mvvec (MultiValue.Cons x) = MultiVector.Cons x

extract ::
   (TypeNum.Positive n, MultiVector.C a) =>
   Exp CounterId -> Exp (LLVM.Vector n a) -> Exp a
extract =
   Expr.liftM2
      (\(MultiValue.Cons k) v ->
         flip MultiVector.extract (mvvec v) =<< LLVM.zext k)

extractBlock :: Exp CounterId -> Exp Block -> Exp Subblock
extractBlock =
   Expr.liftM2
      (\(MultiValue.Cons k) (MultiValue.Cons v) ->
         MultiValue.Cons <$> (LLVM.extractelement v =<< LLVM.zext k))

flattenCounters ::
   Symb.Array (BlockDim, BitDim) Counters ->
   Symb.Array ((BlockDim,CounterDim), BitDim) Counter
flattenCounters =
   Symb.mapWithIndex (\ix block -> extract (Expr.snd (Expr.fst ix)) block) .
   Slice.apply (Slice.first extrudeCounters)


bitPos :: Exp BitId -> Exp Subblock
bitPos = Expr.liftM (MultiValue.liftM LLVM.bitcast)

singleBit :: Exp BitId -> Exp Subblock
singleBit = Expr.shl 1 . bitPos


argMin ::
   (MultiValue.Select x, MultiValue.Select y, MultiValue.Comparison y) =>
   Exp (x,y) -> Exp (x,y) -> Exp (x,y)
argMin xy0 xy1 = Expr.select (Expr.snd xy0 <* Expr.snd xy1) xy0 xy1

argMinimum ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.Select ix) =>
   Symb.Array sh Counter -> Exp ix
argMinimum = Expr.fst . Symb.fold1All argMin . Symb.mapWithIndex Expr.zip

_keepMinimum ::
   IO (Phys.Array (BlockDim, BitDim) Counters ->
       IO ((BlockId,CounterId),Counter))
_keepMinimum =
   Render.run $ Expr.mapSnd singleBit . argMinimum . flattenCounters


argMinMasked ::
   (MultiValue.Select x, MultiValue.Select y, MultiValue.Comparison y) =>
   Exp (Bool, (x,y)) -> Exp (Bool, (x,y)) -> Exp (Bool, (x,y))
argMinMasked xy0 xy1 =
   Expr.select (Expr.fst xy1)
      (Expr.select (Expr.fst xy0)
         (Expr.zip Expr.true $ argMin (Expr.snd xy0) (Expr.snd xy1))
         xy1)
      xy0

testBlockBit :: Exp CounterId -> Exp BitId -> Exp Block -> Exp Bool
testBlockBit k j block = Expr.shr (extractBlock k block) (bitPos j) .&.* 1 /=* 0

flattenBlockBits ::
   Symb.Array BlockDim Block ->
   Symb.Array ((BlockDim,CounterDim), BitDim) Bool
flattenBlockBits =
   Symb.mapWithIndex
      (Expr.modify2 ((atom,atom),atom) atom $ \((_n,k),j) block ->
         testBlockBit k j block) .
   Slice.apply (Slice.compose extrudeCounters extrudeBits)

argMinimumMasked ::
   Symb.Array BlockDim Block ->
   Symb.Array ((BlockDim,CounterDim), BitDim) Counter ->
   Exp ((BlockId,CounterId),BitId)
argMinimumMasked free =
   Expr.fst . Expr.snd . Symb.fold1All argMinMasked .
   Symb.zip (flattenBlockBits free) . Symb.mapWithIndex Expr.zip

_keepMinimumMasked ::
   IO (Phys.Array BlockDim Block ->
       Phys.Array (BlockDim,BitDim) Counters ->
       IO ((BlockId,CounterId),Counter))
_keepMinimumMasked =
   Render.run $ \free ->
      Expr.mapSnd singleBit . argMinimumMasked free . flattenCounters


argMinVec ::
   (TypeNum.Positive n,
    MultiVector.Select x, MultiVector.Select y, MultiVector.Comparison y) =>
   Exp (LLVM.Vector n (x,y)) -> Exp (LLVM.Vector n (x,y)) ->
   Exp (LLVM.Vector n (x,y))
argMinVec xy0 xy1 =
   ExprVec.select
      (ExprVec.cmp LLVM.CmpLT (ExprVec.snd xy0) (ExprVec.snd xy1)) xy0 xy1

argMinMaskedVec ::
   (TypeNum.Positive n,
    MultiVector.Select x, MultiVector.Select y, MultiVector.Comparison y) =>
   Exp (LLVM.Vector n (Bool, (x,y))) -> Exp (LLVM.Vector n (Bool, (x,y))) ->
   Exp (LLVM.Vector n (Bool, (x,y)))
argMinMaskedVec xy0 xy1 =
   ExprVec.select (ExprVec.fst xy1)
      (ExprVec.select (ExprVec.fst xy0)
         (ExprVec.zip (ExprVec.replicate Expr.true) $
          argMinVec (ExprVec.snd xy0) (ExprVec.snd xy1))
         xy1)
      xy0

testBlockBitVec ::
   Exp BitId -> Exp Block -> Exp (LLVM.Vector NumCounters Bool)
testBlockBitVec j block =
   ExprVec.cmp LLVM.CmpNE Expr.zero $ pickBitsX86 j block

flattenBlockBitsVec ::
   Symb.Array BlockDim Block ->
   Symb.Array (BlockDim,BitDim) (LLVM.Vector NumCounters Bool)
flattenBlockBitsVec =
   Symb.mapWithIndex
      (Expr.modify2 (atom,atom) atom $ \(_n,j) block ->
         testBlockBitVec j block) .
   Slice.apply extrudeBits

argMinimumMaskedVec ::
   Symb.Array BlockDim Block ->
   Symb.Array (BlockDim, BitDim) Counters ->
   Exp (LLVM.Vector NumCounters (Bool, ((BlockId, BitId), Counter)))
argMinimumMaskedVec free =
   Symb.fold1All argMinMaskedVec .
   Symb.zipWith ExprVec.zip (flattenBlockBitsVec free) .
   Symb.mapWithIndex (ExprVec.zip . ExprVec.replicate)

counterIds :: Exp (LLVM.Vector NumCounters CounterId)
counterIds = ExprVec.cons (LLVM.vector (NonEmptyC.iterate (1+) 0))

_keepMinimumMaskedVector ::
   Exp (LLVM.Vector NumCounters (Bool, ((BlockId, BitId), Counter))) ->
   Exp ((BlockId, CounterId), BitId)
_keepMinimumMaskedVector =
   Expr.liftM
      (fmap (MultiValue.fst . MultiValue.snd) .
       foldM (Expr.unliftM2 argMinMasked)
         (MultiValue.zip (MultiValue.cons False) MultiValue.undef)
       <=< MultiValueVec.dissect)
   .
   ExprVec.mapSnd
      (ExprVec.mapFst (ExprVec.mapFst (flip ExprVec.zip counterIds)))

type
   IxVector n =
      MultiValue.T (LLVM.Vector n
         (Bool, (((BlockId, CounterId), BitId), Counter)))

argMinMaskedVecHalf ::
   (TypeNum.Positive n, TypeNum.Positive n2, (n:+:n) ~ n2,
    MultiVector.Select x, MultiVector.Select y, MultiVector.Comparison y) =>
   MultiValue.T (LLVM.Vector n2 (Bool, (x, y))) ->
   LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector n (Bool, (x, y))))
argMinMaskedVecHalf x =
   Monad.liftJoin2
      (Expr.unliftM2 argMinMaskedVec)
      (MultiValueVec.take x)
      (MultiValueVec.takeRev x)

keepMinimumMaskedCascade ::
   Exp (LLVM.Vector NumCounters (Bool, ((BlockId, BitId), Counter))) ->
   Exp ((BlockId, CounterId), BitId)
keepMinimumMaskedCascade =
   Expr.fst . Expr.snd
   .
   Expr.liftM
      (\x16 -> do
         x8 <- argMinMaskedVecHalf x16
         x4 <- argMinMaskedVecHalf x8
         x2 <- argMinMaskedVecHalf x4
         Monad.liftJoin2 (Expr.unliftM2 argMinMasked)
            (MultiValueVec.extract (LLVM.valueOf 0) (x2 :: IxVector TypeNum.D2))
            (MultiValueVec.extract (LLVM.valueOf 1) x2))
   .
   ExprVec.mapSnd
      (ExprVec.mapFst (ExprVec.mapFst (flip ExprVec.zip counterIds)))

{- |
In general this function will choose a different minimal element
than 'keepMinimumMasked'.
-}
keepMinimumMaskedVec ::
   IO (Phys.Array BlockDim Block ->
       Phys.Array (BlockDim, BitDim) Counters ->
       IO ((BlockId,CounterId),Subblock))
keepMinimumMaskedVec =
   Render.run $ \free ->
      Expr.mapSnd singleBit . keepMinimumMaskedCascade .
      argMinimumMaskedVec free


affectedRows ::
   IO (Phys.Array (SetDim,BlockDim) Block ->
       ((BlockId,CounterId),Subblock) -> IO [SetId])
affectedRows = do
   affected <-
      Render.run $ \arr ((j,k),bit) ->
         findIndices $ Symb.map (Expr.not . disjoint bit . extractBlock k) $
         Slice.apply (Slice.pickSnd j) $ Symb.fix arr
   return $ \arr bit -> Phys.toList =<< affected arr bit


minimize ::
   IO (Phys.Array BlockDim Block ->
       Phys.Array (SetDim,BlockDim) Block -> IO [SetId])
minimize = do
   smRows <- Render.run (sumRows . uninterleaveBits)
   affected <- affectedRows
   keepMin <- keepMinimumMaskedVec
   return $ \free arr -> affected arr =<< keepMin free =<< smRows arr

stepIO :: IO (State label -> LazyIO.T [State label])
stepIO = do
   update <- updateStateIO
   minim <- minimize
   return $ \s ->
      mapM (flip update s) =<<
      LazyIO.interleave (minim (freeElements s) (snd $ availableSubsets s))

searchIO :: IO (State label -> LazyIO.T [[label]])
searchIO = do
   stp <- stepIO
   nullSt <- Render.run (Expr.bool8FromP . nullSet)
   let srch s = do
         isNull <- LazyIO.interleave $ nullSt (freeElements s)
         if Bool8.toBool isNull
           then return [usedSubsets s]
           else concat <$> (mapM srch =<< stp s)
   return srch

partitionsIO :: (Ord a) => IO ([ESC.Assign label (Set a)] -> LazyIO.T [[label]])
partitionsIO = do
   srch <- searchIO
   return $ srch <=< LazyIO.interleave . initStateIO

partitions :: (Ord a) => [ESC.Assign label (Set a)] -> [[label]]
partitions =
   let parts = unsafePerformIO partitionsIO
   in  unsafePerformIO . LazyIO.run . parts