{-# OPTIONS_HADDOCK hide #-}

module PopKey.Internal2 where

import Control.Monad.ST
import Data.Bit as B
import qualified Data.ByteString as BS
import Data.Either
import Data.Foldable
import HaskellWorks.Data.RankSelect.CsPoppy
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Unboxed.Mutable as MUV
import GHC.Word
import HaskellWorks.Data.Bits.BitWise ((.?.))
import HaskellWorks.Data.RankSelect.Base.Rank0
import HaskellWorks.Data.RankSelect.Base.Rank1
import Unsafe.Coerce

import PopKey.Internal1

-- | Succinct storage for a whole collection of elements of shape @s@.
--
-- Each 'Single' leaf holds one @a@ (a 'PKPrim' once built by 'construct')
-- containing that leaf's value for every element; 'Prod' keeps both
-- components side by side; 'Sum' records which side each element lives on.
data F s a where
  Single :: !a -> F () a
  Prod :: !(F s1 a) -> !(F s2 a) -> F (s1 , s2) a
  -- Fields: cardinality / rank-select index over the side bits / left / right.
  -- Bit 0 means the element is stored in the left child, bit 1 in the right.
  -- The CsPoppy field is deliberately lazy: 'construct' leaves it as an
  -- erroring placeholder when the cardinality is 0, so it must not be forced
  -- in that case.
  Sum :: {-# UNPACK #-} !Word32 -> CsPoppy -> !(F s1 a) -> !(F s2 a) -> F (Either s1 s2) a

-- | A single decoded element of shape @s@, as an ordinary tree.
data F' s a where
  Single' :: a -> F' () a
  Prod' :: (F' s1 a) -> (F' s2 a) -> F' (s1 , s2) a
  Sum' :: (Either (F' s1 a) (F' s2 a)) -> F' (Either s1 s2) a

-- | Structural equality.
instance Eq a => Eq (F' s a) where
  {-# INLINABLE (==) #-}
  (==) (Single' x) (Single' y) = x == y
  (==) (Prod' x1 y1) (Prod' x2 y2) = (x1 == x2) && (y1 == y2)
  (==) (Sum' s1) (Sum' s2) = s1 == s2

-- | Structural ordering: leaves compare by @a@, products lexicographically
-- (left component first), sums via 'Either' (every 'Left' before every
-- 'Right'). 'compare' is defined directly, rather than only '(<=)', so each
-- comparison is a single traversal instead of an '==' pass plus a '<=' pass
-- at every 'Prod'' level.
instance Ord a => Ord (F' s a) where
  {-# INLINABLE compare #-}
  compare (Single' x) (Single' y) = compare x y
  compare (Prod' x1 y1) (Prod' x2 y2) = compare x1 x2 <> compare y1 y2
  compare (Sum' s1) (Sum' s2) = compare s1 s2

-- | Number of elements stored. Both components of a 'Prod' hold the same
-- number of elements (see 'construct'), so the left one is consulted.
flength :: F s PKPrim -> Int
flength (Single a) = pkLength a
flength (Prod x _) = flength x
flength (Sum l _ _ _) = fromIntegral l

-- | Runtime witness of a shape @s@, used to drive 'construct'.
data I s where
  ISingle :: I ()
  IProd :: I s1 -> I s2 -> I (s1 , s2)
  ISum :: I s1 -> I s2 -> I (Either s1 s2)

-- | Decode the element at 0-based index @i@.
--
-- The index must be valid (@0 <= i < flength@ of the structure); this is not
-- checked.
rawq :: Int -> F s PKPrim -> F' s BS.ByteString
rawq i = go
  where
    go :: F s PKPrim -> F' s BS.ByteString
    go (Single pk) = Single' (pkIndex pk i)
    go (Prod x y) = Prod' (go x) (go y)
    -- '(.?.)' takes a 0-based bit position, whereas @rank1 v n@ / @rank0 v n@
    -- count the 1s / 0s among the first @n@ bits. Passing @i@ to both thus
    -- gives the number of same-side elements strictly before @i@, which is
    -- exactly @i@'s 0-based index inside the chosen child.
    go (Sum _ pk l r)
      | pk .?. fromIntegral i = Sum' (Right (rawq (fromIntegral (rank1 pk (fromIntegral i))) r))
      | otherwise = Sum' (Left (rawq (fromIntegral (rank0 pk (fromIntegral i))) l))

-- | Binary search for @q@ over the inclusive index range @[l, r]@ of the
-- stored elements. Returns the index of a matching element, or @-1@ if none
-- is found.
--
-- Requires the elements in that range to be sorted ascending by the 'Ord'
-- instance on 'F''; 'construct' keeps input order and does not sort.
{-# INLINABLE bin_search2 #-}
bin_search2 :: F s PKPrim -> F' s BS.ByteString -> Int -> Int -> Int
bin_search2 vs q = go
  where
    go :: Int -> Int -> Int
    go l r
      | r < l = -1
      | otherwise =
          -- Midpoint written this way (not (l + r) `div` 2) to avoid overflow.
          let m = l + (r - l) `div` 2
          in case compare (rawq m vs) q of
               GT -> go l (m - 1)
               EQ -> m
               LT -> go (m + 1) r

-- | Decode the element at index @i@ (which must be valid) and hand it to the
-- decoder @d@.
{-# INLINE query #-}
query :: forall a s . (F' s BS.ByteString -> a) -> F s PKPrim -> Int -> a
query d pk i = d (rawq i pk)

-- | Build the succinct structure for shape @s@ from a collection, encoding
-- each element with @e@. Input order is preserved: the @k@-th element of the
-- foldable is retrieved by @rawq k@.
{-# INLINABLE construct #-}
construct :: forall a s f . Foldable f => I s -> (a -> F' s BS.ByteString) -> f a -> F s PKPrim
construct = \s e f ->
  if null f
    then fancyZero s
    else go s (foldr ((:) . e) mempty f)
  where
    -- Empty structure of the given shape. 'Sum' nodes get cardinality 0 and a
    -- placeholder rank-select index that must never be forced.
    fancyZero :: forall t . I t -> F t PKPrim
    fancyZero ISingle = Single (ConstSize mempty 0 0)
    fancyZero (IProd x y) = Prod (fancyZero x) (fancyZero y)
    fancyZero (ISum x y) =
      Sum 0 (error "PopKey.Internal2.construct: rank-select index of an empty Sum node was forced") (fancyZero x) (fancyZero y)

    -- Build from a non-empty list of encoded elements (callers guarantee
    -- non-emptiness: the top-level check above, and 'child' below).
    go :: forall t . I t -> [ F' t BS.ByteString ] -> F t PKPrim
    go ISingle = \ys -> Single (makePK (fromSingle <$> ys))
      where
        fromSingle :: F' () BS.ByteString -> BS.ByteString
        fromSingle (Single' x) = x
    go (IProd s1 s2) = \ys ->
      let (as , bs) = unzip (fromProd <$> ys)
      in Prod (go s1 as) (go s2 bs)
      where
        fromProd :: forall s1 s2 . F' (s1 , s2) BS.ByteString -> (F' s1 BS.ByteString , F' s2 BS.ByteString)
        fromProd (Prod' x y) = (x , y)
    go (ISum s1 s2) = \ys ->
      let zs = fromSum <$> ys
          n = length ys
          -- One side bit per element: 0 = stored left, 1 = stored right.
          bv :: UV.Vector Bit = runST do
            v <- MUV.new n
            for_ (zip [ 0 .. ] zs) \(i , x) -> case x of
              Left _ -> MUV.unsafeWrite v i 0
              Right _ -> MUV.unsafeWrite v i 1
            UV.unsafeFreeze v
          -- NOTE(review): 'cloneToWords' yields machine 'Word's; reinterpreting
          -- them as 'Word64' is only valid where 'Word' is 64 bits wide.
          uv64 :: UV.Vector Word64 = unsafeCoerce (cloneToWords bv)
          sv64 :: SV.Vector Word64 = SV.convert uv64
          !(ppy :: CsPoppy) = makeCsPoppy sv64
          (as , bs) = partitionEithers zs
      in Sum (fromIntegral n) ppy (child s1 as) (child s2 bs)
      where
        -- An empty side becomes 'fancyZero' so 'go' is never given [].
        child :: forall u . I u -> [ F' u BS.ByteString ] -> F u PKPrim
        child u [] = fancyZero u
        child u xs = go u xs
        fromSum :: forall s1 s2 . F' (Either s1 s2) BS.ByteString -> Either (F' s1 BS.ByteString) (F' s2 BS.ByteString)
        fromSum (Sum' x) = x