module Data.PrimitiveArray.Index.Set where
import           Control.Applicative ((<$>),(<*>))
import           Control.DeepSeq (NFData(..))
import           Data.Aeson (FromJSON,ToJSON)
import           Data.Binary (Binary)
import           Data.Bits
import           Data.Bits.Extras
import           Data.Hashable (Hashable)
import           Data.Serialize (Serialize)
import           Data.Vector.Unboxed.Deriving
import           Data.Vector.Unboxed (Unbox(..))
import           Debug.Trace
import           GHC.Generics (Generic)
import qualified Data.Vector.Fusion.Stream.Monadic as SM
import qualified Data.Vector.Unboxed as VU
import           Test.QuickCheck (Arbitrary(..), choose, elements)
import           Data.Bits.Ordered
import           Data.PrimitiveArray.Index.Class
import           Data.PrimitiveArray.Index.IOC
import           Data.PrimitiveArray.Vector.Compat
-- | A single position inside a set, stored as a bit index (it is used with
-- 'testBit' on the owning set). The type parameter @t@ is a phantom tag only.
newtype Interface t = Iter { getIter :: Int }
  deriving (Eq,Ord,Generic,Num)

-- | Shows an interface as @(I:n)@.
instance Show (Interface t) where
  show (Iter i) = concat ["(I:", show i, ")"]
-- | Empty marker types (they have no values). They are not referenced in
-- this part of the file; presumably they serve as phantom tags, e.g. for
-- 'Interface' — TODO confirm against their users.
data First
data Last
data Any
-- | A set of small integers represented by the bits of an 'Int'. The phantom
-- parameter @t@ (the tags 'I', 'O', 'C') selects which 'IndexStream'
-- instance, and hence which streaming direction, is used.
newtype BitSet t = BitSet { getBitSet :: Int }
  deriving (Eq,Ord,Read,Generic,FiniteBits,Ranked,Num,Bits)
-- | Wrap a raw bit pattern as a 'BitSet' tagged 'I'.
bitSetI :: Int -> BitSet I
bitSetI bits = BitSet bits
-- | Wrap a raw bit pattern as a 'BitSet' tagged 'O'.
bitSetO :: Int -> BitSet O
bitSetO bits = BitSet bits
-- | Wrap a raw bit pattern as a 'BitSet' tagged 'C'.
bitSetC :: Int -> BitSet C
bitSetC bits = BitSet bits
-- | A set together with one interface (a bit index). Well-formed values keep
-- the interface on an active bit of the set, or at 0 when the set is empty;
-- this is what the 'Arbitrary' instance generates.
data BS1 i t = BS1 !(BitSet t) !(Interface i)
deriving instance Show (BS1 i t)
-- | A set together with two interfaces. Generated values use two distinct
-- active bits for sets with at least two elements, the same bit twice for
-- singletons, and 0 for both interfaces of the empty set.
data BS2 i j t = BS2 !(BitSet t) !(Interface i) !(Interface j)
deriving instance Show (BS2 i j t)
-- | Stepwise enumeration of set-like indices between a lower and an upper
-- bound. The stream generators in this module call these methods with the
-- stream's bounds to move from one element to the next.
class SetPredSucc s where
  -- | @setSucc l h s@ is the element after @s@ within the bounds @l@ and
  -- @h@, or 'Nothing' when the enumeration is exhausted.
  setSucc :: s -> s -> s -> Maybe s
  -- | @setPred l h s@ is the element before @s@ within the bounds @l@ and
  -- @h@, or 'Nothing' when the enumeration is exhausted.
  setPred :: s -> s -> s -> Maybe s
-- | The type of mask used with a set-like index. For every set type in this
-- module the mask is a plain 'BitSet' of the same phantom type.
type family Mask s :: *
-- | An index value ('getFixed', strict) paired with a mask ('getFixedMask',
-- lazy) that marks bits to be treated as fixed while enumerating.
data Fixed t = Fixed { getFixedMask :: (Mask t) , getFixed :: !t }
-- | Apply a mask to an index. All instances in this module are built on
-- 'popShiftL'.
class ApplyMask s where
  applyMask :: Mask s -> s -> s
-- Unboxed-vector support for 'Interface': each element is stored as its
-- underlying 'Int'.
derivingUnbox "Interface"
  [t| forall t . Interface t -> Int |]
  [| \(Iter i) -> i            |]
  [| Iter                      |]
-- Binary, cereal, JSON and hashing support for 'Interface'. None of these
-- instances has a body, so each relies on its class's default (Generic-based)
-- implementation.
instance Binary    (Interface t)
instance Serialize (Interface t)
instance ToJSON    (Interface t)
instance FromJSON  (Interface t)
instance Hashable  (Interface t)
-- | Fully evaluating an interface means evaluating its bit index.
instance NFData (Interface t) where
  rnf = rnf . getIter
  
-- | Interfaces are plain integer positions. The linear index of @z@ is its
-- offset from the lower bound, and the index space is the inclusive range
-- between the two bounds.
instance Index (Interface i) where
  -- offset of @z@ from the lower bound
  linearIndex l _ (Iter z) = z - smallestLinearIndex l

  smallestLinearIndex (Iter l) = l

  largestLinearIndex (Iter h) = h

  -- the inclusive range [l .. h] holds h - l + 1 positions
  size (Iter l) (Iter h) = h - l + 1

  inBounds l h z = l <= z && z <= h
  
-- Unboxed-vector support for 'BitSet': each element is stored as its
-- underlying 'Int'.
derivingUnbox "BitSet"
  [t| forall t . BitSet t -> Int |]
  [| \(BitSet s) -> s   |]
  [| BitSet             |]
-- | Shows the list of active bits followed by the raw integer, e.g.
-- @<[0,2]>(5)@.
instance Show (BitSet t) where
  show (BitSet s) = concat ["<", show (activeBitsL s), ">(", show s, ")"]
-- Binary, cereal, JSON and hashing support for 'BitSet'. None of these
-- instances has a body, so each relies on its class's default (Generic-based)
-- implementation.
instance Binary    (BitSet t)
instance Serialize (BitSet t)
instance ToJSON    (BitSet t)
instance FromJSON  (BitSet t)
instance Hashable  (BitSet t)
-- | Fully evaluating a set means evaluating its underlying 'Int'.
instance NFData (BitSet t) where
  rnf = rnf . getBitSet
  
-- | Sets are indexed by their raw integer value, bounded by population
-- count. The smallest integer with @k@ active bits is @2^k - 1@ (the low @k@
-- bits set), which gives the lower end of the index range for a bound with
-- @k@ active bits.
instance Index (BitSet t) where
  -- offset of the raw value from the smallest set allowed by @l@
  linearIndex l _ (BitSet z) = z - smallestLinearIndex l

  smallestLinearIndex l = 2 ^ popCount l - 1

  largestLinearIndex h = 2 ^ popCount h - 1

  -- largestLinearIndex h - smallestLinearIndex l + 1
  size l h = 2 ^ popCount h - 2 ^ popCount l + 1

  -- only the number of active bits is compared against the bounds
  inBounds l h z = popCount l <= popCount z && popCount z <= popCount h
  
-- Sets tagged 'I': 'streamUp' uses the ascending generator
-- (streamUpBsMk/streamUpBsStep) and 'streamDown' the descending one.
instance IndexStream z => IndexStream (z:.BitSet I) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamUpBsMk   l h) (streamUpBsStep   l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamDownBsMk l h) (streamDownBsStep l h) $ streamDown ls hs
  
  
-- Sets tagged 'O' run in the opposite direction: 'streamUp' uses the
-- descending generator and 'streamDown' the ascending one.
instance IndexStream z => IndexStream (z:.BitSet O) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamDownBsMk l h) (streamDownBsStep l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamUpBsMk   l h) (streamUpBsStep   l h) $ streamDown ls hs
  
  
-- Sets tagged 'C' stream in the same direction as 'I'.
instance IndexStream z => IndexStream (z:.BitSet C) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamUpBsMk   l h) (streamUpBsStep   l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamDownBsMk l h) (streamDownBsStep l h) $ streamDown ls hs
  
  
-- A bare 'BitSet' defines no methods here and relies on the class defaults
-- (presumably delegating to the @Z:.BitSet t@ instance required in the
-- context — confirm in Data.PrimitiveArray.Index.Class).
instance IndexStream (Z:.BitSet t) => IndexStream (BitSet t)
-- | Seed of an ascending stream: start at the lower bound @l@, or produce
-- nothing when @l > h@.
streamUpBsMk :: (Monad m, Ord a) => a -> a -> t -> m (t, Maybe a)
streamUpBsMk l h z = return (z, start)
  where start | l <= h    = Just l
              | otherwise = Nothing
-- | Emit the current element, then advance with 'setSucc'.
streamUpBsStep :: (Monad m, SetPredSucc s) => s -> s -> (t, Maybe s) -> m (SM.Step (t, Maybe s) (t :. s))
streamUpBsStep l h (z, cur) = case cur of
  Nothing -> return SM.Done
  Just t  -> return $ SM.Yield (z:.t) (z, setSucc l h t)
-- | Seed of a descending stream: start at the upper bound @h@, or produce
-- nothing when @l > h@.
streamDownBsMk :: (Monad m, Ord a) => a -> a -> t -> m (t, Maybe a)
streamDownBsMk l h z = return (z, start)
  where start | l <= h    = Just h
              | otherwise = Nothing
-- | Emit the current element, then step back with 'setPred'.
streamDownBsStep :: (Monad m, SetPredSucc s) => s -> s -> (t, Maybe s) -> m (SM.Step (t, Maybe s) (t :. s))
streamDownBsStep l h (z, cur) = case cur of
  Nothing -> return SM.Done
  Just t  -> return $ SM.Yield (z:.t) (z, setPred l h t)
-- | A 'BS1' is indexed exactly like the pair @set :. interface@; every method
-- delegates to that instance.
instance Index (BS1 i t) where
  linearIndex (BS1 loSet loIx) (BS1 hiSet hiIx) (BS1 set ix)
    = linearIndex (loSet:.loIx) (hiSet:.hiIx) (set:.ix)

  smallestLinearIndex (BS1 set ix)
    = smallestLinearIndex (set:.ix)

  largestLinearIndex (BS1 set ix)
    = largestLinearIndex (set:.ix)

  size (BS1 loSet loIx) (BS1 hiSet hiIx)
    = size (loSet:.loIx) (hiSet:.hiIx)

  inBounds (BS1 loSet loIx) (BS1 hiSet hiIx) (BS1 set ix)
    = inBounds (loSet:.loIx) (hiSet:.hiIx) (set:.ix)
  
-- 'BS1' tagged 'I': ascending generator for 'streamUp', descending for
-- 'streamDown'.
instance IndexStream z => IndexStream (z:.BS1 i I) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamUpBsIMk   l h) (streamUpBsIStep   l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamDownBsIMk l h) (streamDownBsIStep l h) $ streamDown ls hs
  
  
-- 'BS1' tagged 'O': directions swapped relative to 'I'.
instance IndexStream z => IndexStream (z:.BS1 i O) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamDownBsIMk l h) (streamDownBsIStep l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamUpBsIMk   l h) (streamUpBsIStep   l h) $ streamDown ls hs
  
  
-- 'BS1' tagged 'C': same directions as 'I'.
instance IndexStream z => IndexStream (z:.BS1 i C) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamUpBsIMk   l h) (streamUpBsIStep   l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamDownBsIMk l h) (streamDownBsIStep l h) $ streamDown ls hs
  
  
-- A bare 'BS1' relies on the class defaults (presumably delegating to the
-- @Z:.BS1 i t@ instance in the context — confirm).
instance IndexStream (Z:.BS1 i t) => IndexStream (BS1 i t)
-- | Seed of an ascending 'BS1' stream: start at the lower-bound set @sl@ with
-- the interface at @lsbZ sl@, clamped to 0 (presumably to cover the empty
-- set). Produces nothing when @sl > sh@.
streamUpBsIMk :: (Monad m) => BS1 a i -> BS1 b i -> z -> m (z, Maybe (BS1 c i))
streamUpBsIMk (BS1 sl _) (BS1 sh _) z = return (z, if sl <= sh then Just (BS1 sl (Iter . max 0 $ lsbZ sl)) else Nothing)
-- | Emit the current 'BS1', then advance with 'setSucc'.
streamUpBsIStep :: (Monad m, SetPredSucc s) => s -> s -> (t, Maybe s) -> m (SM.Step (t, Maybe s) (t :. s))
streamUpBsIStep l h (z , Nothing) = return $ SM.Done
streamUpBsIStep l h (z,  Just t ) = return $ SM.Yield (z:.t) (z , setSucc l h t)
-- | Seed of a descending 'BS1' stream: start at the upper-bound set @sh@ with
-- the interface at @lsbZ sh@, clamped to 0. The set and the interface both
-- come from @sh@, so the interface refers to a bit of the set it is paired
-- with (as in 'streamDownBsMk' and 'streamDownBsIiMk').
streamDownBsIMk :: (Monad m) => BS1 a i -> BS1 b i -> z -> m (z, Maybe (BS1 c i))
streamDownBsIMk (BS1 sl _) (BS1 sh _) z = return (z, if sl <= sh then Just (BS1 sh (Iter . max 0 $ lsbZ sh)) else Nothing)
-- | Emit the current 'BS1', then step back with 'setPred'.
streamDownBsIStep :: (Monad m, SetPredSucc s) => s -> s -> (t, Maybe s) -> m (SM.Step (t, Maybe s) (t :. s))
streamDownBsIStep l h (z , Nothing) = return $ SM.Done
streamDownBsIStep l h (z , Just t ) = return $ SM.Yield (z:.t) (z , setPred l h t)
-- | A 'BS2' is indexed exactly like the triple @set :. first :. second@;
-- every method delegates to that instance.
instance Index (BS2 i j t) where
  linearIndex (BS2 loSet loI loJ) (BS2 hiSet hiI hiJ) (BS2 set ix jx)
    = linearIndex (loSet:.loI:.loJ) (hiSet:.hiI:.hiJ) (set:.ix:.jx)

  smallestLinearIndex (BS2 set ix jx)
    = smallestLinearIndex (set:.ix:.jx)

  largestLinearIndex (BS2 set ix jx)
    = largestLinearIndex (set:.ix:.jx)

  size (BS2 loSet loI loJ) (BS2 hiSet hiI hiJ)
    = size (loSet:.loI:.loJ) (hiSet:.hiI:.hiJ)

  inBounds (BS2 loSet loI loJ) (BS2 hiSet hiI hiJ) (BS2 set ix jx)
    = inBounds (loSet:.loI:.loJ) (hiSet:.hiI:.hiJ) (set:.ix:.jx)
  
-- 'BS2' tagged 'I': ascending generator for 'streamUp', descending for
-- 'streamDown'.
instance IndexStream z => IndexStream (z:.BS2 i j I) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamUpBsIiMk   l h) (streamUpBsIiStep   l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamDownBsIiMk l h) (streamDownBsIiStep l h) $ streamDown ls hs
  
  
-- 'BS2' tagged 'O': directions swapped relative to 'I'.
instance IndexStream z => IndexStream (z:.BS2 i j O) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamDownBsIiMk l h) (streamDownBsIiStep l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamUpBsIiMk   l h) (streamUpBsIiStep   l h) $ streamDown ls hs
  
  
-- 'BS2' tagged 'C': same directions as 'I'.
instance IndexStream z => IndexStream (z:.BS2 i j C) where
  streamUp   (ls:.l) (hs:.h) = flatten (streamUpBsIiMk   l h) (streamUpBsIiStep   l h) $ streamUp   ls hs
  streamDown (ls:.l) (hs:.h) = flatten (streamDownBsIiMk l h) (streamDownBsIiStep l h) $ streamDown ls hs
  
  
-- A bare 'BS2' relies on the class defaults (presumably delegating to the
-- @Z:.BS2 i j t@ instance in the context — confirm).
instance IndexStream (Z:.BS2 i j t) => IndexStream (BS2 i j t)
-- | Seed of an ascending 'BS2' stream starting at the lower-bound set @sl@.
-- Both interfaces are 0 for the empty set and coincide on the single bit of
-- a singleton. Otherwise the first interface is @lsbZ sl@ and the second is
-- @lsbZ@ of the set with that bit cleared. Produces nothing when @sl > sh@.
streamUpBsIiMk :: (Monad m) => BS2 a b i -> BS2 c d i -> z -> m (z, Maybe (BS2 e f i))
streamUpBsIiMk (BS2 sl _ _) (BS2 sh _ _) z
  | sl > sh   = return (z, Nothing)
  | otherwise = case popCount sl of
      0 -> return (z, Just (BS2 0 0 0))
      1 -> let k = lsbZ sl
           in  return (z, Just (BS2 sl (Iter k) (Iter k)))
      _ -> let k  = lsbZ sl
               k' = lsbZ (sl `clearBit` k)
           in  return (z, Just (BS2 sl (Iter k) (Iter k')))
-- | Emit the current 'BS2', then advance with 'setSucc'.
streamUpBsIiStep :: (Monad m, SetPredSucc s) => s -> s -> (t, Maybe s) -> m (SM.Step (t, Maybe s) (t :. s))
streamUpBsIiStep l h (z, cur) = case cur of
  Nothing -> return SM.Done
  Just t  -> return $ SM.Yield (z:.t) (z, setSucc l h t)
-- | Seed of a descending 'BS2' stream starting at the upper-bound set @sh@.
-- Both interfaces are 0 for the empty set and coincide on the single bit of
-- a singleton. For larger sets the two interfaces must be distinct bits (the
-- invariant 'streamUpBsIiMk' and the 'Arbitrary' instance maintain), so the
-- second interface is taken after clearing the first. Produces nothing when
-- @sl > sh@.
streamDownBsIiMk :: (Monad m) => BS2 a b i -> BS2 c d i -> z -> m (z, Maybe (BS2 e f i))
streamDownBsIiMk (BS2 sl _ _) (BS2 sh _ _) z
  | sl > sh   = return (z , Nothing)
  | ch == 0   = return (z , Just (BS2 0 0 0))
  | ch == 1   = let i = lsbZ sh
                in  return (z , Just (BS2 sh (Iter i) (Iter i)))
  | otherwise = let i = lsbZ sh; j = lsbZ (sh `clearBit` i)
                in  return (z , Just (BS2 sh (Iter i) (Iter j)))
  where ch = popCount sh
-- | Emit the current 'BS2', then step back with 'setPred'.
streamDownBsIiStep :: (Monad m, SetPredSucc s) => s -> s -> (t, Maybe s) -> m (SM.Step (t, Maybe s) (t :. s))
streamDownBsIiStep l h (z , Nothing) = return $ SM.Done
streamDownBsIiStep l h (z , Just t ) = return $ SM.Yield (z:.t) (z , setPred l h t)
-- | Sets are enumerated by population count. Within one population count
-- the next set comes from 'popPermutation'. When that runs out, 'setSucc'
-- moves to the smallest set with one more active bit and 'setPred' to the
-- smallest set with one fewer. The bounds only limit the population count.
instance SetPredSucc (BitSet t) where
  setSucc l h s
    | cs > ch                        = Nothing
    | Just s' <- popPermutation ch s = Just s'
    | cs >= ch                       = Nothing
    -- smallest set with cs+1 active bits: the low cs+1 bits
    | cs < ch                        = Just . BitSet $ 2^(cs+1) - 1
    where ch = popCount h
          cs = popCount s

  setPred l h s
    | cs < cl                        = Nothing
    | Just s' <- popPermutation ch s = Just s'
    | cs <= cl                       = Nothing
    -- smallest set with cs-1 active bits: the low cs-1 bits
    | cs > cl                        = Just . BitSet $ 2^(cs-1) - 1
    where cl = popCount l
          ch = popCount h
          cs = popCount s
  
-- | For a fixed set, the interface first moves over the set's active bits
-- ('maybeNextActive'). Afterwards the set advances as in the 'BitSet'
-- instance and the interface restarts at 'lsbZ' of the new set.
instance SetPredSucc (BS1 i t) where
  setSucc (BS1 l il) (BS1 h ih) (BS1 s (Iter is))
    | cs > ch                          = Nothing
    | Just is' <- maybeNextActive is s = Just $ BS1 s  (Iter is')
    | Just s'  <- popPermutation ch s  = Just $ BS1 s' (Iter $ lsbZ s')
    | cs >= ch                         = Nothing
    -- smallest set with cs+1 active bits (never empty, so no clamping)
    | cs < ch                          = let s' = BitSet $ 2^(cs+1) - 1 in Just (BS1 s' (Iter (lsbZ s')))
    where ch = popCount h
          cs = popCount s

  setPred (BS1 l il) (BS1 h ih) (BS1 s (Iter is))
    | cs < cl                          = Nothing
    | Just is' <- maybeNextActive is s = Just $ BS1 s  (Iter is')
    | Just s'  <- popPermutation ch s  = Just $ BS1 s' (Iter  $ lsbZ s')
    | cs <= cl                         = Nothing
    -- smallest set with cs-1 active bits; this may be the empty set, hence
    -- the clamp of the interface to 0
    | cs > cl                          = let s' = BitSet $ 2^(cs-1) - 1 in Just (BS1 s' (Iter (max 0 $ lsbZ s')))
    where cl = popCount l
          ch = popCount h
          cs = popCount s
  
-- | Enumeration of 'BS2' indices. The empty set and singletons are special
-- cases (both interfaces coincide). For larger sets the second interface
-- moves first, then the first interface, then the set itself advances as in
-- the 'BitSet' instance.
instance SetPredSucc (BS2 i j t) where
  setSucc (BS2 l il jl) (BS2 h ih jh) (BS2 s (Iter is) (Iter js))
    -- more active bits than the upper bound allows
    | cs > ch                         = Nothing
    -- the empty set is followed by the singleton {0}
    | cs == 0                         = Just (BS2 1 0 0)
    -- singletons: next singleton, both interfaces on its only bit
    | cs == 1
    , Just s'  <- popPermutation ch s
    , let is' = lsbZ s'          = Just (BS2 s' (Iter is') (Iter is'))
    -- advance the second interface within the same set
    | Just js' <- maybeNextActive js (s `clearBit` is) = Just (BS2 s (Iter is) (Iter js'))
    -- advance the first interface and reset the second
    | Just is' <- maybeNextActive is s
    , let js' = lsbZ (s `clearBit` is')      = Just (BS2 s (Iter is') (Iter js'))
    -- next set with the same population count
    | Just s'  <- popPermutation ch s
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (BS2 s' (Iter is') (Iter js'))
    -- no larger population count allowed
    | cs >= ch                        = Nothing
    -- smallest set with cs+1 active bits (the low cs+1 bits)
    | cs < ch
    , let s' = BitSet $ 2^(cs+1) - 1
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (BS2 s' (Iter is') (Iter js'))
    where ch = popCount h
          cs = popCount s

  setPred (BS2 l il jl) (BS2 h ih jh) (BS2 s (Iter is) (Iter js))
    -- fewer active bits than the lower bound allows
    | cs < cl                         = Nothing
    -- the empty set has no predecessor
    | cs == 0                         = Nothing
    -- singletons: next singleton, both interfaces on its only bit
    | cs == 1
    , Just s'  <- popPermutation ch s
    , let is' = lsbZ s'          = Just (BS2 s' (Iter is') (Iter is'))
    -- singletons exhausted: continue with the empty set
    | cs == 1                         = Just (BS2 0 0 0)
    -- advance the second interface within the same set
    | Just js' <- maybeNextActive js (s `clearBit` is) = Just (BS2 s (Iter is) (Iter js'))
    -- advance the first interface and reset the second
    | Just is' <- maybeNextActive is s
    , let js' = lsbZ (s `clearBit` is')      = Just (BS2 s (Iter is') (Iter js'))
    -- next set with the same population count
    | Just s'  <- popPermutation ch s
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (BS2 s' (Iter is') (Iter js'))
    -- cannot drop below the lower bound's population count
    | cs <= cl                        = Nothing
    -- smallest set with cs-1 (>= 2) active bits
    | cs > cl && cs > 2
    , let s' = BitSet $ 2^(cs-1) - 1
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (BS2 s' (Iter is') (Iter js'))
    -- two-element sets are followed by the singleton {0}
    | cs > cl && cs == 2              = Just (BS2 1 0 0)
    where cl = popCount l
          ch = popCount h
          cs = popCount s
  
-- Every set-like index in this module is masked by a plain 'BitSet' with the
-- same phantom type.
type instance Mask (BitSet t)  = BitSet t
type instance Mask (BS1 i t)   = BitSet t
type instance Mask (BS2 i j t) = BitSet t
-- Unboxed-vector support for 'Fixed': each element is stored as the pair
-- (mask, value), which requires both components to be 'Unbox'.
derivingUnbox "Fixed"
  [t| forall t . (Unbox t, Unbox (Mask t)) => Fixed t -> (Mask t, t) |]
  [| \(Fixed m s) -> (m,s)              |]
  [| uncurry Fixed                      |]
-- Standalone-derived instances for 'Fixed'. The mask's type is computed by
-- the 'Mask' type family, so each context constrains both @t@ and @Mask t@.
deriving instance (Eq t     , Eq      (Mask t)) => Eq      (Fixed t)
deriving instance (Ord t    , Ord     (Mask t)) => Ord     (Fixed t)
deriving instance (Read t   , Read    (Mask t)) => Read    (Fixed t)
deriving instance (Show t   , Show    (Mask t)) => Show    (Fixed t)
deriving instance (Generic t, Generic (Mask t)) => Generic (Fixed t)
-- Hashing and serialization use the classes' default (Generic-based)
-- implementations.
instance (Generic t, Generic (Mask t), Hashable t, Hashable (Mask t)) => Hashable (Fixed t)
instance (Generic t, Generic (Mask t), Binary t   , Binary    (Mask t)) => Binary    (Fixed t)
instance (Generic t, Generic (Mask t), Serialize t, Serialize (Mask t)) => Serialize (Fixed t)
-- | Forces the mask and the value to weak head normal form only; no 'NFData'
-- is required of either component.
instance NFData (Fixed t) where
  rnf (Fixed m s) = seq m (seq s ())
-- | Ad-hoc helper (apparently for manual testing): the successor of the set
-- @k@ with mask 4 (bit 2 fixed), between the bounds 0 and 7.
testBsS :: BitSet t -> Maybe (Fixed (BitSet t))
testBsS = setSucc (Fixed 0 0) (Fixed 0 7) . Fixed 4
-- | Enumeration of a 'BitSet' in which the bits of the mask are held fixed.
instance SetPredSucc (Fixed (BitSet t)) where
  -- Steps the set with all masked bits cleared.
  -- NOTE(review): unlike 'setSucc', this neither restores the fixed bits nor
  -- compresses the free bits with 'popShiftR', so the result has the masked
  -- bits cleared. Confirm this asymmetry is intended.
  setPred (Fixed _ l) (Fixed _ h) (Fixed !m s) = Fixed m <$> setPred l h (s .&. complement m)

  -- The mask is restricted to the upper bound. 'popShiftR' removes the
  -- masked positions from the set and from both bounds, and the free part is
  -- advanced as an ordinary set. 'popShiftL' maps the result back, and the
  -- fixed bits that were set in @s@ are OR-ed in again.
  setSucc (Fixed _ l) (Fixed _ h) (Fixed !m' s) = (Fixed m . (.|. fb1)) <$> p
    where m   = m' .&. h
          fb1 = m  .&. s
          p'  = popShiftR m s
          p'' = setSucc (popShiftR m l) (popShiftR m h) p'
          p   = popShiftL m <$> p''
  
-- | Enumeration of a 'BS1' in which the bits of the mask are held fixed.
-- If the interface bit is set in @s@, the set with the masked bits cleared is
-- stepped as a plain 'BitSet'. The interface bit is then set again and the
-- interface is kept unchanged. Otherwise (e.g. the empty set) the masked-out
-- set is stepped as an ordinary 'BS1'.
-- NOTE(review): apart from the interface bit, masked bits are not put back
-- into the result — confirm this is intended.
instance SetPredSucc (Fixed (BS1 i t)) where
  setPred (Fixed _ (BS1 l li)) (Fixed _ (BS1 h hi)) (Fixed !m (BS1 s i))
    | s `testBit` getIter i = (Fixed m . (`BS1` i) . ( `setBit` getIter i)) <$> setPred l h (s .&. complement m)
    | otherwise             = (Fixed m) <$> setPred (BS1 l li) (BS1 h hi) (BS1 (s .&. complement m) i)
  
  setSucc (Fixed _ (BS1 l li)) (Fixed _ (BS1 h hi)) (Fixed !m (BS1 s i))
    | s `testBit` getIter i = (Fixed m . (`BS1` i) . ( `setBit` getIter i)) <$> setSucc l h (s .&. complement m)
    | otherwise             = (Fixed m) <$> setSucc (BS1 l li) (BS1 h hi) (BS1 (s .&. complement m) i)
  
-- | Enumeration of a 'BS2' in which the bits of the mask are held fixed.
-- * If both interface bits are set in @s@, the set with the masked bits
--   cleared is stepped as a plain 'BitSet' and both interface bits are set
--   again.
-- * If only one interface bit is set, the remainder is stepped as a 'BS1'
--   carrying the other interface, and the set interface bit is restored.
-- * If neither bit is set (e.g. the empty set), the masked-out set is
--   stepped as an ordinary 'BS2'. This mirrors the fallback of the 'BS1'
--   instance; without it both methods would be partial.
instance SetPredSucc (Fixed (BS2 i j t)) where
  setPred (Fixed _ (BS2 l li lj)) (Fixed _ (BS2 h hi hj)) (Fixed !m (BS2 s i j))
    | s `testBit` getIter i && s `testBit` getIter j
    = (Fixed m . (\z       -> BS2 (z `setBit` getIter i `setBit` getIter j) i j)) <$> setPred l h (s .&. complement m)
    | s `testBit` getIter i
    = (Fixed m . (\(BS1 z j') -> BS2 (z `setBit` getIter i) i j')) <$> setPred (BS1 l lj) (BS1 h hj) (BS1 (s .&. complement m) j)
    | s `testBit` getIter j
    = (Fixed m . (\(BS1 z i') -> BS2 (z `setBit` getIter j) i' j)) <$> setPred (BS1 l li) (BS1 h hi) (BS1 (s .&. complement m) i)
    | otherwise
    = Fixed m <$> setPred (BS2 l li lj) (BS2 h hi hj) (BS2 (s .&. complement m) i j)

  setSucc (Fixed _ (BS2 l li lj)) (Fixed _ (BS2 h hi hj)) (Fixed !m (BS2 s i j))
    | s `testBit` getIter i && s `testBit` getIter j
    = (Fixed m . (\z       -> BS2 (z `setBit` getIter i `setBit` getIter j) i j)) <$> setSucc l h (s .&. complement m)
    | s `testBit` getIter i
    = (Fixed m . (\(BS1 z j') -> BS2 (z `setBit` getIter i) i j')) <$> setSucc (BS1 l lj) (BS1 h hj) (BS1 (s .&. complement m) j)
    | s `testBit` getIter j
    = (Fixed m . (\(BS1 z i') -> BS2 (z `setBit` getIter j) i' j)) <$> setSucc (BS1 l li) (BS1 h hi) (BS1 (s .&. complement m) i)
    | otherwise
    = Fixed m <$> setSucc (BS2 l li lj) (BS2 h hi hj) (BS2 (s .&. complement m) i j)
  
-- | A plain set is masked by applying 'popShiftL' to it directly.
instance ApplyMask (BitSet t) where
  applyMask m s = popShiftL m s
  
-- | Masks the set with 'popShiftL'. The interface is mapped by masking the
-- one-bit set @{i}@ the same way and reading back the position of its bit
-- with 'lsbZ'. Interfaces are bit indices (they are used with 'testBit' and
-- built from 'lsbZ' elsewhere), so the raw value of the shifted one-bit set,
-- which is a power of two, must not be used directly. The index is clamped
-- at 0, as in the stream seeds.
instance ApplyMask (BS1 i t) where
  applyMask m (BS1 s i)
    | popCount s == 0 = BS1 0 0
    | otherwise       = BS1 (popShiftL m s) (Iter . max 0 . lsbZ . popShiftL m . BitSet $ 2 ^ getIter i)
  
-- | Masks the set with 'popShiftL'. Each interface is mapped by masking its
-- one-bit set the same way and taking the bit position with 'lsbZ', so the
-- result is again a bit index rather than a power-of-two bitmask. For
-- singletons both interfaces share the mapped position.
instance ApplyMask (BS2 i j t) where
  applyMask m (BS2 s i j)
    | popCount s == 0 = BS2 0 0 0
    | popCount s == 1 = BS2 s' i' (Iter $ getIter i')
    | otherwise       = BS2 s' i' j'
    where s' = popShiftL m s
          i' = Iter . max 0 . lsbZ . popShiftL m $ (BitSet $ 2 ^ getIter i :: BitSet t)
          j' = Iter . max 0 . lsbZ . popShiftL m $ (BitSet $ 2 ^ getIter j :: BitSet t)
  
-- | Exponent that bounds the raw value of arbitrary 'BitSet's. The explicit
-- signature matches the type the binding defaulted to ('Integer').
arbitraryBitSetMax :: Integer
arbitraryBitSetMax = 6
-- | Random 'Fixed' values combine an independent mask and value.
instance (Arbitrary t, Arbitrary (Mask t)) => Arbitrary (Fixed t) where
  arbitrary = Fixed <$> arbitrary <*> arbitrary
  -- Shrink one component at a time (as QuickCheck does for pairs), so that
  -- candidates exist even when the other component is already minimal.
  shrink (Fixed m s) =  [ Fixed m' s  | m' <- shrink m ]
                     ++ [ Fixed m  s' | s' <- shrink s ]
-- | Random sets use only the low 'arbitraryBitSetMax' bits, i.e. raw values
-- in @[0, 2^arbitraryBitSetMax - 1]@. Shrinking clears one active bit at a
-- time and then shrinks each candidate recursively.
instance Arbitrary (BitSet t) where
  arbitrary = BitSet <$> choose (0, 2^arbitraryBitSetMax - 1)
  shrink s = let s' = [ s `clearBit` a | a <- activeBitsL s ]
             in  s' ++ concatMap shrink s'
-- | Random 'BS1' values: the interface is a uniformly chosen active bit of
-- the set, or 0 for the empty set. Shrinking clears bits other than the
-- interface, turns a singleton into the empty set, and recurses.
instance Arbitrary (BS1 i t) where
  arbitrary = arbitrary >>= pick
    where
      pick s
        | s == 0    = return (BS1 s 0)
        | otherwise = BS1 s . Iter <$> elements (activeBitsL s)
  shrink (BS1 s i) = candidates ++ concatMap shrink candidates
    where
      candidates =  [ BS1 (s `clearBit` a) i | a <- activeBitsL s, Iter a /= i ]
                 ++ [ BS1 0 0 | popCount s == 1 ]
-- | Random 'BS2' values. The empty set gets both interfaces at 0, a
-- singleton has both interfaces on its single active bit, and larger sets
-- get two distinct active bits (the second is chosen after clearing the
-- first).
instance Arbitrary (BS2 i j t) where
  arbitrary = do
    s <- arbitrary
    case (popCount s) of
      0 -> return (BS2 s 0 0)
      1 -> do i <- elements $ activeBitsL s
              return (BS2 s (Iter i) (Iter i))
      _ -> do i <- elements $ activeBitsL s
              j <- elements $ activeBitsL (s `clearBit` i)
              return (BS2 s (Iter i) (Iter j))
  -- Shrinking clears active bits other than the two interfaces, reduces a
  -- two-element set to each of its singletons, and reduces a singleton to the
  -- empty set. Every candidate is then shrunk recursively.
  shrink (BS2 s i j) =
    let s' = [ BS2 (s `clearBit` a) i j
             | a <- activeBitsL s
             , Iter a /= i, Iter a /= j ]
             ++ [ BS2 (0 `setBit` a) (Iter a) (Iter a)
                | popCount s == 2
                , a <- activeBitsL s ]
             ++ [ BS2 0 0 0
                | popCount s == 1 ]
    in  s' ++ concatMap shrink s'