module Data.PrimitiveArray.Index.Set where
import           Control.Applicative ((<$>),(<*>))
import           Control.DeepSeq (NFData(..))
import           Data.Aeson (FromJSON,ToJSON)
import           Data.Binary (Binary)
import           Data.Bits
import           Data.Bits.Extras
import           Data.Serialize (Serialize)
import           Data.Vector.Fusion.Stream.Size
import           Data.Vector.Unboxed.Deriving
import           Data.Vector.Unboxed (Unbox(..))
import           Debug.Trace
import           GHC.Generics
import qualified Data.Vector.Fusion.Stream.Monadic as SM
import qualified Data.Vector.Unboxed as VU
import           Test.QuickCheck (Arbitrary(..), choose, elements)
import           Data.Bits.Ordered
import           Data.PrimitiveArray.Index.Class
-- | An interface: a plain 'Int' tagged with a phantom type parameter @t@
-- (the parameter does not occur in the wrapped value). Elsewhere in this
-- module the wrapped 'Int' is used as a bit position of an accompanying
-- 'BitSet' (it is passed to 'testBit', 'setBit' and 'clearBit').
newtype Interface t = Iter { getIter :: Int }
  deriving (Eq,Ord,Read,Show,Generic,Num)
-- | Constructor-less type. Presumably a phantom tag for 'Interface'
-- marking the first element of a set -- TODO confirm against callers.
data First
-- | Constructor-less type. Presumably a phantom tag for 'Interface'
-- marking the last element of a set -- TODO confirm against callers.
data Last
-- | Constructor-less type. Presumably a phantom tag for 'Interface'
-- marking an arbitrary element of a set -- TODO confirm against callers.
data Any
-- | A set encoded in the bits of a single 'Int'. The numeric and bit-level
-- classes are derived from the wrapped 'Int'; 'Show' is not derived but
-- written by hand further down in this module.
newtype BitSet = BitSet { getBitSet :: Int }
  deriving (Eq,Ord,Read,Generic,FiniteBits,Ranked,Num,Bits)
-- | A 'BitSet' paired with one 'Interface'.
type BS1I i = BitSet:>Interface i
-- | A 'BitSet' paired with two 'Interface's.
type BS2I i j = BitSet:>Interface i:>Interface j
-- | Stepping forwards and backwards through a set-like index space.
-- Both methods take three arguments of the index type; the instances in
-- this module name them @l@ (lower bound), @h@ (upper bound) and @s@
-- (current index), in that order.
class SetPredSucc s where
  -- | @setSucc l h s@ yields the index following @s@, or 'Nothing' once
  -- there is no further index within the bounds.
  setSucc :: s -> s -> s -> Maybe s
  -- | @setPred l h s@ yields the index preceding @s@, or 'Nothing' once
  -- there is no further index within the bounds.
  setPred :: s -> s -> s -> Maybe s
-- | Open type family mapping an index type @s@ to the type of its mask.
-- Every instance given in this module maps to 'BitSet'.
type family Mask s :: *
-- | An index of type @t@ together with a mask of type @'Mask' t@.
-- Note the asymmetry: the mask field is lazy, the index field is strict.
data Fixed t = Fixed { getFixedMask :: (Mask t) , getFixed :: !t }
-- | Index types to which a mask of type @'Mask' s@ can be applied.
class ApplyMask s where
  -- | Combine a mask with an index, producing an index of the same type.
  -- The instances in this module implement this via @popShiftL@.
  applyMask :: Mask s -> s -> s
-- Template Haskell splice deriving the unboxed-vector instances for
-- 'Interface' by representing it as its wrapped 'Int'
-- (unwrap with the first function, re-wrap with the second).
derivingUnbox "Interface"
  [t| forall t . Interface t -> Int |]
  [| \(Iter i) -> i            |]
  [| Iter                      |]
-- Serialisation instances for 'Interface'. The bodies are empty, so the
-- class default methods are used (presumably the 'Generic'-based defaults,
-- given that 'Interface' derives 'Generic').
instance Binary    (Interface t)
instance Serialize (Interface t)
instance ToJSON    (Interface t)
instance FromJSON  (Interface t)
-- | Fully evaluating an 'Interface' amounts to fully evaluating the
-- wrapped 'Int'.
instance NFData (Interface t) where
  rnf = rnf . getIter
  
-- | Linear indexing for a single 'Interface': the wrapped 'Int' is used
-- directly as the position, with bounds given by a lower and an upper
-- 'Interface'.
instance Index (Interface i) where
  -- Offset of @z@ from the smallest linear index of the lower bound.
  -- The upper bound is not needed.
  linearIndex l _ (Iter z) = z - smallestLinearIndex l

  -- The smallest linear index is the wrapped value of the lower bound.
  smallestLinearIndex (Iter l) = l

  -- The largest linear index is the wrapped value of the upper bound.
  largestLinearIndex (Iter h) = h

  -- Number of positions in the inclusive range from @l@ to @h@.
  size (Iter l) (Iter h) = h - l + 1

  -- Inclusive on both ends.
  inBounds l h z = l <= z && z <= h
  
-- Template Haskell splice deriving the unboxed-vector instances for
-- 'BitSet' by representing it as its wrapped 'Int'.
derivingUnbox "BitSet"
  [t| BitSet     -> Int |]
  [| \(BitSet s) -> s   |]
  [| BitSet             |]
-- | Renders a 'BitSet' as @\<bits\>(raw)@: first the result of
-- @activeBitsL@ on the wrapped 'Int' (presumably the list of set bit
-- positions), then the wrapped 'Int' itself in parentheses.
instance Show BitSet where
  show (BitSet s) = concat ["<", show (activeBitsL s), ">(", show s, ")"]
-- Serialisation instances for 'BitSet'. The bodies are empty, so the
-- class default methods are used (presumably the 'Generic'-based defaults,
-- given that 'BitSet' derives 'Generic').
instance Binary    BitSet
instance Serialize BitSet
instance ToJSON    BitSet
instance FromJSON  BitSet
-- | Fully evaluating a 'BitSet' amounts to fully evaluating the wrapped
-- 'Int'.
instance NFData BitSet where
  rnf = rnf . getBitSet
  
-- | Linear indexing for 'BitSet'. The raw 'Int' value of the set is its
-- position; bounds are interpreted through their population count only.
instance Index BitSet where
  -- Raw value of the set, shifted so the smallest linear index of the lower
  -- bound maps to 0. The upper bound is not needed.
  linearIndex l _ (BitSet z) = z - smallestLinearIndex l

  -- @2^k - 1@ where @k@ is the number of bits set in the lower bound.
  smallestLinearIndex l = 2 ^ popCount l - 1

  -- @2^k - 1@ where @k@ is the number of bits set in the upper bound.
  largestLinearIndex h = 2 ^ popCount h - 1

  -- Equals @largestLinearIndex h - smallestLinearIndex l + 1@.
  size l h = 2 ^ popCount h - 2 ^ popCount l + 1

  -- Bounds are checked on population count only, not on which bits are set.
  inBounds l h z = popCount l <= popCount z && popCount z <= popCount h
  
-- | Streams all 'BitSet' indices between two bounds as the innermost
-- dimension of a larger index @z@. For every outer index produced by the
-- stream over @z@, the sets are enumerated by repeatedly calling 'setSucc'
-- (upwards) or 'setPred' (downwards) until they return 'Nothing'.
instance IndexStream z => IndexStream (z:.BitSet) where
  streamUp (ls:.l) (hs:.h) = SM.flatten seed next Unknown $ streamUp ls hs
    where
      -- Start from the lower bound; yield nothing when the bounds are
      -- ordered the wrong way round.
      seed z
        | l <= h    = return (z , Just l)
        | otherwise = return (z , Nothing)
      next (_ , Nothing) = return SM.Done
      next (z , Just t ) = return $ SM.Yield (z:.t) (z , setSucc l h t)

  streamDown (ls:.l) (hs:.h) = SM.flatten seed next Unknown $ streamDown ls hs
    where
      -- Start from the upper bound; yield nothing when the bounds are
      -- ordered the wrong way round. The state is kept in a @:.@ pair here.
      seed z
        | l <= h    = return (z :. Just h)
        | otherwise = return (z :. Nothing)
      next (_ :. Nothing) = return SM.Done
      next (z :. Just t ) = return $ SM.Yield (z:.t) (z :. setPred l h t)
          
          
  
-- | Streams all (set, interface) indices between two bounds as the
-- innermost dimension of a larger index @z@. Only the set components of
-- the bounds decide where the enumeration starts; the interface of the
-- first element is computed from the starting set via @lsbZ@, clamped to
-- be non-negative.
instance IndexStream z => IndexStream (z:.(BitSet:>Interface i)) where
  streamUp (ls:.l@(sl:>_)) (hs:.h@(sh:>_)) = SM.flatten seed next Unknown $ streamUp ls hs
    where
      -- First element: the lower-bound set with the interface @lsbZ@ picks.
      seed z
        | sl <= sh  = return (z , Just (sl :> Iter (max 0 (lsbZ sl))))
        | otherwise = return (z , Nothing)
      next (_ , Nothing) = return SM.Done
      next (z , Just t ) = return $ SM.Yield (z:.t) (z , setSucc l h t)

  streamDown (ls:.l@(sl:>_)) (hs:.h@(sh:>_)) = SM.flatten seed next Unknown $ streamDown ls hs
    where
      -- First element: the upper-bound set with the interface @lsbZ@ picks.
      seed z
        | sl <= sh  = return (z , Just (sh :> Iter (max 0 (lsbZ sh))))
        | otherwise = return (z , Nothing)
      next (_ , Nothing) = return SM.Done
      next (z , Just t ) = return $ SM.Yield (z:.t) (z , setPred l h t)
          
          
  
-- | Streams all (set, interface, interface) indices between two bounds as
-- the innermost dimension of a larger index @z@. Only the set components
-- of the bounds decide where the enumeration starts. The initial element
-- depends on how many bits the starting set has:
--
-- * no bits: the empty set with both interfaces at 0;
-- * one bit: both interfaces point at that single element;
-- * otherwise: the first interface is the element @lsbZ@ picks, the second
--   is the element @lsbZ@ picks once the first has been removed, so the
--   two interfaces are distinct.
instance IndexStream z => IndexStream (z:.(BitSet:>Interface i:>Interface j)) where
  streamUp (ls:.l@(sl:>_:>_)) (hs:.h@(sh:>_:>_)) = SM.flatten mk step Unknown $ streamUp ls hs
    where mk z | sl > sh   = return (z , Nothing)
               | cl == 0   = return (z , Just (0:>0:>0))
               | cl == 1   = let i = lsbZ sl
                             in  return (z , Just (sl :> Iter i :> Iter i))
               | otherwise = let i = lsbZ sl; j = lsbZ (sl `clearBit` i)
                             in  return (z , Just (sl :> Iter i :> Iter j))
               where cl = popCount sl
          step (z , Nothing) = return $ SM.Done
          step (z , Just t ) = return $ SM.Yield (z:.t) (z , setSucc l h t)

  streamDown (ls:.l@(sl:>_:>_)) (hs:.h@(sh:>_:>_)) = SM.flatten mk step Unknown $ streamDown ls hs
    where mk z | sl > sh   = return (z , Nothing)
               | ch == 0   = return (z , Just (0:>0:>0))
               | ch == 1   = let i = lsbZ sh
                             in  return (z , Just (sh :> Iter i :> Iter i))
               -- The second interface must be taken from the set with the
               -- first interface removed, exactly as in 'streamUp';
               -- otherwise both interfaces would coincide on a set with
               -- two or more elements.
               | otherwise = let i = lsbZ sh; j = lsbZ (sh `clearBit` i)
                             in  return (z , Just (sh :> Iter i :> Iter j))
               where ch = popCount sh
          step (z , Nothing) = return $ SM.Done
          step (z , Just t ) = return $ SM.Yield (z:.t) (z , setPred l h t)
          
          
  
-- | Successor and predecessor on plain 'BitSet's. Sets are grouped by
-- population count; within one group the next set is obtained from
-- @popPermutation@ (presumably the next bit pattern with the same number
-- of set bits, limited to @popCount h@ bits -- confirm against
-- "Data.Bits.Ordered"). When a group is exhausted the enumeration moves to
-- the neighbouring population count, starting at the set whose low bits
-- are all set.
instance SetPredSucc BitSet where
  setSucc _ h s
    -- already beyond the upper bound
    | cs > ch                        = Nothing
    -- another set with the same population count exists
    | Just s' <- popPermutation ch s = Just s'
    -- this was the last set with the maximal population count
    | cs >= ch                       = Nothing
    -- move up one population count: the lowest @cs+1@ bits set
    | otherwise                      = Just . BitSet $ 2^(cs+1) - 1
    where ch = popCount h
          cs = popCount s

  setPred l h s
    -- already below the lower bound
    | cs < cl                        = Nothing
    -- another set with the same population count exists
    | Just s' <- popPermutation ch s = Just s'
    -- this was the last set with the minimal population count
    | cs <= cl                       = Nothing
    -- move down one population count: the lowest @cs-1@ bits set
    | otherwise                      = Just . BitSet $ 2^(cs-1) - 1
    where cl = popCount l
          ch = popCount h
          cs = popCount s
  
-- | Successor and predecessor on a 'BitSet' with one 'Interface'.
-- For a given set the interface is advanced first (via @maybeNextActive@,
-- presumably the next set bit after the given position -- confirm against
-- "Data.Bits.Ordered"); only when that is exhausted does the set itself
-- change, with the interface reset to the position @lsbZ@ reports.
instance SetPredSucc (BitSet:>Interface i) where
  setSucc (l:>il) (h:>ih) (s:>Iter is)
    -- already beyond the upper bound
    | cs > ch                         = Nothing
    -- same set, next interface
    | Just is' <- maybeNextActive is s     = Just (s:>Iter is')
    -- next set with the same population count, interface reset
    | Just s'  <- popPermutation ch s = Just (s':>Iter (lsbZ s'))
    -- last set with the maximal population count
    | cs >= ch                        = Nothing
    -- move up one population count: the lowest @cs+1@ bits set
    | cs < ch                         = let s' = BitSet $ 2^(cs+1) - 1 in Just (s' :> Iter (lsbZ s'))
    where ch = popCount h
          cs = popCount s

  setPred (l:>il) (h:>ih) (s:>Iter is)
    -- already below the lower bound
    | cs < cl                         = Nothing
    -- same set, next interface
    | Just is' <- maybeNextActive is s     = Just (s:>Iter is')
    -- next set with the same population count, interface reset
    | Just s'  <- popPermutation ch s = Just (s':>Iter (lsbZ s'))
    -- last set with the minimal population count
    | cs <= cl                        = Nothing
    -- move down one population count: the lowest @cs-1@ bits set. The new
    -- set may be empty, hence the interface is clamped to be non-negative.
    | cs > cl                         = let s' = BitSet $ 2^(cs-1) - 1 in Just (s' :> Iter (max 0 $ lsbZ s'))
    where cl = popCount l
          ch = popCount h
          cs = popCount s
  
-- | Successor and predecessor on a 'BitSet' with two 'Interface's.
-- For a given set the second interface is advanced first, then the first
-- interface (resetting the second), and only then the set itself. The
-- helpers @popPermutation@, @maybeNextActive@ and @lsbZ@ come from
-- "Data.Bits.Ordered"; their exact semantics should be confirmed there.
instance SetPredSucc (BitSet:>Interface i:>Interface j) where
  setSucc (l:>il:>jl) (h:>ih:>jh) (s:>Iter is:>Iter js)
    -- already beyond the upper bound
    | cs > ch                         = Nothing
    -- empty set: continue with the set @1@, both interfaces on the same
    -- (only) element
    | cs == 0                         = Just (1:>0:>0)
    -- single element: permute the set and point both interfaces at the
    -- element @lsbZ@ reports
    | cs == 1
    , Just s'  <- popPermutation ch s
    , let is' = lsbZ s'          = Just (s':>Iter is':>Iter is')
    -- advance the second interface only; the first interface's element is
    -- excluded from the candidates
    | Just js' <- maybeNextActive js (s `clearBit` is) = Just (s:>Iter is:>Iter js')
    -- advance the first interface and reset the second
    | Just is' <- maybeNextActive is s
    , let js' = lsbZ (s `clearBit` is')      = Just (s:>Iter is':>Iter js')
    -- next set with the same population count, both interfaces reset
    | Just s'  <- popPermutation ch s
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (s':>Iter is':>Iter js')
    -- last set with the maximal population count
    | cs >= ch                        = Nothing
    -- move up one population count: the lowest @cs+1@ bits set
    | cs < ch
    , let s' = BitSet $ 2^(cs+1) - 1
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (s':>Iter is':>Iter js')
    where ch = popCount h
          cs = popCount s

  setPred (l:>il:>jl) (h:>ih:>jh) (s:>Iter is:>Iter js)
    -- already below the lower bound
    | cs < cl                         = Nothing
    -- nothing precedes the empty set
    | cs == 0                         = Nothing
    -- single element: permute the set and point both interfaces at the
    -- element @lsbZ@ reports
    | cs == 1
    , Just s'  <- popPermutation ch s
    , let is' = lsbZ s'          = Just (s':>Iter is':>Iter is')
    -- single element, no permutation left: continue with the empty set
    | cs == 1                         = Just (0:>0:>0)
    -- advance the second interface only; the first interface's element is
    -- excluded from the candidates
    | Just js' <- maybeNextActive js (s `clearBit` is) = Just (s:>Iter is:>Iter js')
    -- advance the first interface and reset the second
    | Just is' <- maybeNextActive is s
    , let js' = lsbZ (s `clearBit` is')      = Just (s:>Iter is':>Iter js')
    -- next set with the same population count, both interfaces reset
    | Just s'  <- popPermutation ch s
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (s':>Iter is':>Iter js')
    -- last set with the minimal population count
    | cs <= cl                        = Nothing
    -- move down one population count (result still has two or more
    -- elements): the lowest @cs-1@ bits set
    | cs > cl && cs > 2
    , let s' = BitSet $ 2^(cs-1) - 1
    , let is' = lsbZ s'
    , Just js' <- maybeNextActive is' s'   = Just (s':>Iter is':>Iter js')
    -- move down from two elements to the single-element set @1@
    | cs > cl && cs == 2              = Just (1:>0:>0)
    where cl = popCount l
          ch = popCount h
          cs = popCount s
  
-- For all three set index types (with zero, one and two interfaces) the
-- mask is a plain 'BitSet'.
type instance Mask BitSet = BitSet
type instance Mask (BitSet :> Interface i) = BitSet
type instance Mask (BitSet :> Interface i :> Interface j) = BitSet
-- Template Haskell splice deriving the unboxed-vector instances for
-- 'Fixed' by representing it as the pair (mask, index). Requires both the
-- index type and its mask type to be unboxable.
derivingUnbox "Fixed"
  [t| forall t . (Unbox t, Unbox (Mask t)) => Fixed t -> (Mask t, t) |]
  [| \(Fixed m s) -> (m,s)              |]
  [| uncurry Fixed                      |]
-- Standalone-derived instances for 'Fixed'. They cannot be attached to the
-- data declaration with a plain deriving clause because each one needs the
-- class for both the index type @t@ and the type-family application
-- @'Mask' t@ in its context.
deriving instance (Eq t     , Eq      (Mask t)) => Eq      (Fixed t)
deriving instance (Ord t    , Ord     (Mask t)) => Ord     (Fixed t)
deriving instance (Read t   , Read    (Mask t)) => Read    (Fixed t)
deriving instance (Show t   , Show    (Mask t)) => Show    (Fixed t)
deriving instance (Generic t, Generic (Mask t)) => Generic (Fixed t)
-- Binary serialisation for 'Fixed'. The bodies are empty, so the class
-- default methods are used (presumably the 'Generic'-based defaults, which
-- is why 'Generic' appears in the contexts).
instance (Generic t, Generic (Mask t), Binary t   , Binary    (Mask t)) => Binary    (Fixed t)
instance (Generic t, Generic (Mask t), Serialize t, Serialize (Mask t)) => Serialize (Fixed t)
-- | Forces both the mask and the index with 'seq', i.e. to weak head
-- normal form only. There is no 'NFData' constraint on @t@ or on
-- @'Mask' t@, so nothing deeper can be forced here.
instance NFData (Fixed t) where
  rnf (Fixed m s) = seq m (seq s ())
-- | Small experiment helper: one 'setSucc' step on a 'Fixed' 'BitSet'
-- with mask @4@, between the bounds @0@ and @7@ (both bounds carry the
-- mask @0@), starting from the given set.
testBsS :: BitSet -> Maybe (Fixed BitSet)
testBsS k = setSucc lower upper current
  where lower   = Fixed 0 0
        upper   = Fixed 0 7
        current = Fixed 4 k
-- | Successor and predecessor on a 'BitSet' whose masked bits are held
-- fixed. The masks stored in the two bounds are ignored; only the mask of
-- the current index is used.
instance SetPredSucc (Fixed BitSet) where
  -- Strip the masked bits from the set, step the remainder with the plain
  -- 'BitSet' predecessor and re-attach the mask. Note that the masked bits
  -- are not put back into the resulting set.
  setPred (Fixed _ l) (Fixed _ h) (Fixed !m s) = Fixed m <$> setPred l h (s .&. complement m)

  -- Step in a compressed space and expand again:
  --
  -- 1. restrict the mask to the bits of the upper bound (@m@);
  -- 2. remember which masked bits are currently set (@fb1@);
  -- 3. compress set and bounds with @popShiftR m@ (presumably removing the
  --    masked bit positions -- confirm against "Data.Bits.Ordered");
  -- 4. take the plain 'BitSet' successor in the compressed space;
  -- 5. expand with @popShiftL m@ and put the remembered bits back.
  --
  -- The result carries the restricted mask @m@.
  setSucc (Fixed _ l) (Fixed _ h) (Fixed !m' s) = (Fixed m . (.|. fb1)) <$> p
    where m   = m' .&. h
          fb1 = m  .&. s
          p'  = popShiftR m s
          p'' = setSucc (popShiftR m l) (popShiftR m h) p'
          p   = popShiftL m <$> p''
  
-- | Successor and predecessor on a masked (set, interface) index. The
-- masks stored in the two bounds are ignored. In both directions the
-- masked bits are stripped from the set before stepping:
--
-- * if the interface's bit is set in the set, only the set is stepped
--   (as a plain 'BitSet'), the interface is kept and its bit is set again
--   in the result;
-- * otherwise set and interface are stepped together.
instance SetPredSucc (Fixed (BitSet:>Interface i)) where
  setPred (Fixed _ (l:>li)) (Fixed _ (h:>hi)) (Fixed !m (s:>i))
    | testBit s k = fmap (\z -> Fixed m (setBit z k :> i)) (setPred l h free)
    | otherwise   = fmap (Fixed m) (setPred (l:>li) (h:>hi) (free :> i))
    where k    = getIter i
          -- the set without its masked bits
          free = s .&. complement m

  setSucc (Fixed _ (l:>li)) (Fixed _ (h:>hi)) (Fixed !m (s:>i))
    | testBit s k = fmap (\z -> Fixed m (setBit z k :> i)) (setSucc l h free)
    | otherwise   = fmap (Fixed m) (setSucc (l:>li) (h:>hi) (free :> i))
    where k    = getIter i
          -- the set without its masked bits
          free = s .&. complement m
  
-- | Successor and predecessor on a masked (set, interface, interface)
-- index. The masks stored in the two bounds are ignored. In both
-- directions the masked bits are stripped from the set before stepping,
-- and an interface whose bit is set in the set is kept in place:
--
-- * both interface bits set: only the set is stepped, both bits are set
--   again in the result;
-- * only the first interface bit set: the set and the second interface are
--   stepped together, the first interface's bit is set again;
-- * only the second interface bit set: the set and the first interface are
--   stepped together, the second interface's bit is set again;
-- * neither bit set (e.g. the empty set): set and both interfaces are
--   stepped together, mirroring the fallback of the one-interface
--   instance. Without this case the guards were not exhaustive and such an
--   index caused a pattern-match failure at run time.
instance SetPredSucc (Fixed (BitSet:>Interface i:>Interface j)) where
  setPred (Fixed _ (l:>li:>lj)) (Fixed _ (h:>hi:>hj)) (Fixed !m (s:>i:>j))
    | s `testBit` getIter i && s `testBit` getIter j
    = (Fixed m . (\z       -> (z `setBit` getIter i `setBit` getIter j:>i:>j ))) <$> setPred l h (s .&. complement m)
    | s `testBit` getIter i
    = (Fixed m . (\(z:>j') -> (z `setBit` getIter i                   :>i:>j'))) <$> setPred (l:>lj) (h:>hj) (s .&. complement m :>j)
    | s `testBit` getIter j
    = (Fixed m . (\(z:>i') -> (z `setBit` getIter j                   :>i':>j))) <$> setPred (l:>li) (h:>hi) (s .&. complement m :>i)
    | otherwise
    = (Fixed m) <$> setPred (l:>li:>lj) (h:>hi:>hj) ((s .&. complement m):>i:>j)

  setSucc (Fixed _ (l:>li:>lj)) (Fixed _ (h:>hi:>hj)) (Fixed !m (s:>i:>j))
    | s `testBit` getIter i && s `testBit` getIter j
    = (Fixed m . (\z       -> (z `setBit` getIter i `setBit` getIter j:>i:>j ))) <$> setSucc l h (s .&. complement m)
    | s `testBit` getIter i
    = (Fixed m . (\(z:>j') -> (z `setBit` getIter i                   :>i:>j'))) <$> setSucc (l:>lj) (h:>hj) (s .&. complement m :>j)
    | s `testBit` getIter j
    = (Fixed m . (\(z:>i') -> (z `setBit` getIter j                   :>i':>j))) <$> setSucc (l:>li) (h:>hi) (s .&. complement m :>i)
    | otherwise
    = (Fixed m) <$> setSucc (l:>li:>lj) (h:>hi:>hj) ((s .&. complement m):>i:>j)
  
-- | Applying a mask to a plain 'BitSet' is exactly @popShiftL@ with the
-- mask as first argument.
instance ApplyMask BitSet where
  applyMask m s = popShiftL m s
  
-- | Applies a mask to a (set, interface) index. The empty set maps to the
-- empty set with interface 0. Otherwise the set is moved with @popShiftL@,
-- and the interface is moved by building the single-bit set @2^i@ and
-- pushing it through the same @popShiftL@.
--
-- NOTE(review): the new interface stores the raw 'Int' value of the moved
-- single-bit set, not a bit position -- confirm that this is intended.
instance ApplyMask (BitSet :> Interface i) where
  applyMask m (s:>i)
    | popCount s == 0 = 0 :> 0
    | otherwise       = movedSet :> movedIface
    where movedSet   = popShiftL m s
          movedIface = Iter (getBitSet (popShiftL m (BitSet (2 ^ getIter i))))
  
-- | Applies a mask to a (set, interface, interface) index. The empty set
-- maps to the empty set with both interfaces 0. For a single-element set
-- the second interface is copied from the moved first interface. Otherwise
-- the set and both interfaces are moved independently: the set with
-- @popShiftL@, each interface by pushing the single-bit set @2^k@ through
-- the same @popShiftL@.
--
-- NOTE(review): the new interfaces store the raw 'Int' value of the moved
-- single-bit set, not a bit position -- confirm that this is intended.
instance ApplyMask (BitSet :> Interface i :> Interface j) where
  applyMask m (s:>i:>j) = case popCount s of
      0 -> 0 :> 0 :> 0
      1 -> moved :> i' :> Iter (getIter i')
      _ -> moved :> i' :> j'
    where moved = popShiftL m s
          i'    = Iter (relocate (getIter i))
          j'    = Iter (relocate (getIter j))
          -- push the single-bit set for position @k@ through the mask
          relocate k = getBitSet (popShiftL m (BitSet (2 ^ k)))
  
-- | Size parameter for randomly generated 'BitSet's; it appears as the
-- exponent of 2 in the range given to @choose@ in the 'Arbitrary' 'BitSet'
-- instance. No type signature is given, so its type is inferred.
arbitraryBitSetMax = 6
-- | Random 'Fixed' values: mask and index are generated independently.
instance (Arbitrary t, Arbitrary (Mask t)) => Arbitrary (Fixed t) where
  arbitrary = Fixed <$> arbitrary <*> arbitrary
  -- Shrink one component at a time, as QuickCheck does for pairs. Shrinking
  -- both at once produces no candidates at all as soon as either the mask
  -- or the index cannot be shrunk any further.
  shrink (Fixed m s) = [ Fixed m' s | m' <- shrink m ] ++ [ Fixed m s' | s' <- shrink s ]
-- | Random 'BitSet's with raw values between @0@ and
-- @2^arbitraryBitSetMax - 1@ inclusive.
instance Arbitrary BitSet where
  arbitrary = BitSet <$> choose (0,2^arbitraryBitSetMax - 1)
  -- Candidates are obtained by clearing one of the bits reported by
  -- @activeBitsL@, followed by the recursive shrinks of those candidates.
  shrink s = let s' = [ s `clearBit` a | a <- activeBitsL s ]
             in  s' ++ concatMap shrink s'
-- | Random (set, interface) pairs in which the interface is one of the
-- bits reported by @activeBitsL@ for the set; the empty set gets
-- interface 0.
instance Arbitrary (BitSet:>Interface i) where
  arbitrary = arbitrary >>= pick
    where
      pick s
        | s == 0    = return (s :> Iter 0)
        | otherwise = do
            i <- elements (activeBitsL s)
            return (s :> Iter i)
  -- Direct candidates, followed by their recursive shrinks.
  shrink (s:>i) = direct ++ concatMap shrink direct
    where
      direct  = dropped ++ emptied
      -- clear any one bit except the one the interface points at
      dropped = [ (s `clearBit` a) :> i | a <- activeBitsL s, Iter a /= i ]
      -- a single-element set shrinks to the empty set with interface 0
      emptied = [ 0 :> Iter 0 | popCount s == 1 ]
-- | Random (set, interface, interface) triples. The empty set gets both
-- interfaces at 0; a single-element set has both interfaces on that
-- element; otherwise the first interface is drawn from the set and the
-- second from the set with the first interface's bit cleared.
instance Arbitrary (BitSet:>Interface i:>Interface j) where
  arbitrary = arbitrary >>= build
    where
      build s
        | n == 0    = return (s :> Iter 0 :> Iter 0)
        | n == 1    = do i <- elements (activeBitsL s)
                         return (s :> Iter i :> Iter i)
        | otherwise = do i <- elements (activeBitsL s)
                         j <- elements (activeBitsL (s `clearBit` i))
                         return (s :> Iter i :> Iter j)
        where n = popCount s
  -- Direct candidates, followed by their recursive shrinks.
  shrink (s:>i:>j) = direct ++ concatMap shrink direct
    where
      direct     = dropped ++ singletons ++ emptied
      -- clear any one bit that neither interface points at
      dropped    = [ (s `clearBit` a) :> i :> j
                   | a <- activeBitsL s
                   , Iter a /= i, Iter a /= j ]
      -- a two-element set shrinks to each of its single-element sets,
      -- with both interfaces on the remaining element
      singletons = [ 0 `setBit` a :> Iter a :> Iter a
                   | popCount s == 2
                   , a <- activeBitsL s ]
      -- a single-element set shrinks to the empty set
      emptied    = [ 0 :> Iter 0 :> Iter 0
                   | popCount s == 1 ]