module Data.SearchEngine.TermBag (
TermId(..), TermCount,
TermBag,
size,
fromList,
toList,
elems,
termCount,
denseTable,
invariant
) where
import qualified Data.Vector.Unboxed as Vec
import qualified Data.Vector.Unboxed.Mutable as MVec
import qualified Data.Vector.Generic.Base as VecGen
import qualified Data.Vector.Unboxed.Base as VecBase
import qualified Data.Vector.Generic.Mutable as VecMut
import Control.Monad.ST
import qualified Data.Map as Map
import Data.Word (Word32, Word8)
import Data.Bits
import Data.List (sortBy, foldl')
import Data.Function (on)
-- | Identifier of an indexed term, represented as a 'Word32'.
--
-- The vector-related instances allow 'TermId' to be stored in unboxed
-- 'Vec.Vector's. Only the low 24 bits of an id survive being packed into a
-- 'TermBag' entry (see 'termIdAndCount').
newtype TermId = TermId Word32
  deriving ( Eq
           , Ord
           , Show
           , Enum
           , Vec.Unbox
           , VecGen.Vector VecBase.Vector
           , VecMut.MVector VecBase.MVector
           )
-- | Term ids are restricted to 24 bits, because the upper 8 bits of a packed
-- bag entry hold the term's count.
instance Bounded TermId where
  minBound = TermId 0
  maxBound = TermId (bit 24 - 1)
-- | A bag (multiset) of term ids together with per-term occurrence counts.
--
-- The 'Int' field is the total number of term occurrences (as computed by
-- 'fromList'). The vector holds one packed entry per distinct term, and
-- 'invariant' checks that the entries are in strictly ascending term-id order.
data TermBag
  = TermBag !Int !(Vec.Vector TermIdAndCount)
  deriving Show

-- | A term id and its count packed into one word: the count occupies the top
-- 8 bits and the term id the low 24 bits.
type TermIdAndCount = Word32

-- | Per-term occurrence count. It saturates at 255 when packed.
type TermCount = Word8
-- | Pack a term id and its occurrence count into a single word.
--
-- The count is saturated to the range [0, 255] and stored in the top 8 bits.
-- The term id is truncated to its low 24 bits.
termIdAndCount :: TermId -> Int -> TermIdAndCount
termIdAndCount (TermId termid) freq =
    (saturated `shiftL` 24) .|. (termid .&. 0x00FFFFFF)
  where
    -- Clamp in the Int domain *before* narrowing to Word32. Narrowing first
    -- would make a count of 2^32 wrap to 0 and a negative count wrap to 255.
    saturated :: Word32
    saturated = fromIntegral (max 0 (min 255 freq))
-- | Extract the term id stored in the low 24 bits of a packed entry.
getTermId :: TermIdAndCount -> TermId
getTermId = TermId . (.&. 0x00FFFFFF)
-- | Extract the count stored in the top 8 bits of a packed entry.
getTermCount :: TermIdAndCount -> TermCount
getTermCount = fromIntegral . (`shiftR` 24)
-- | Check the structural invariant of a 'TermBag': its packed entries are in
-- strictly ascending term-id order, which also rules out duplicate ids.
-- The stored total size is not checked.
invariant :: TermBag -> Bool
invariant (TermBag _ vec) = and (zipWith (<) ids (drop 1 ids))
  where
    ids = map getTermId (Vec.toList vec)
-- | Total number of term occurrences recorded in the bag, as computed by
-- 'fromList'. This is not the number of distinct terms.
size :: TermBag -> Int
size bag = case bag of
  TermBag total _ -> total
-- | The distinct term ids in the bag, in the order they are stored
-- (ascending when 'invariant' holds).
elems :: TermBag -> [TermId]
elems (TermBag _ packed) = Vec.toList (Vec.map getTermId packed)
-- | The distinct terms of the bag in stored order, each paired with its
-- (saturated) count.
toList :: TermBag -> [(TermId, TermCount)]
toList (TermBag _ packed) = map unpack (Vec.toList packed)
  where
    unpack entry = (getTermId entry, getTermCount entry)
-- | Look up how often a term occurs in the bag (saturated at 255).
-- Returns 0 when the term is absent.
--
-- This is a binary search over the packed entries, so it relies on them being
-- sorted by term id (see 'invariant').
termCount :: TermBag -> TermId -> TermCount
termCount (TermBag _ vec) =
    binarySearch 0 (Vec.length vec - 1)
  where
    -- Search the inclusive index range [a, b]. An empty bag gives [0, -1],
    -- which is immediately empty.
    binarySearch :: Int -> Int -> TermId -> TermCount
    binarySearch !a !b !key
      | a > b     = 0
      | otherwise =
          let mid         = a + (b - a) `div` 2
              tidAndCount = vec Vec.! mid
          in case compare key (getTermId tidAndCount) of
               LT -> binarySearch a (mid - 1) key
               EQ -> getTermCount tidAndCount
               GT -> binarySearch (mid + 1) b key
-- | Build a bag from a list of term occurrences. A term that occurs @k@
-- times gets count @k@, saturated at 255 when packed.
--
-- The bag's size is the total number of occurrences, i.e. the length of the
-- input list. Entries are stored in ascending term-id order.
fromList :: [TermId] -> TermBag
fromList termids = TermBag sz vec
  where
    -- Count occurrences per term. 'Data.Map' is value-lazy, so
    -- 'Map.fromListWith' (+) would leave a chain of (+) thunks per term.
    -- Forcing each count as it is updated ('Just $!') avoids that.
    counts :: Map.Map TermId Int
    counts = foldl' bump Map.empty termids
    bump m t = Map.alter (\old -> Just $! maybe 1 (+ 1) old) t m

    sz  = Map.foldl' (+) 0 counts
    vec = Vec.fromListN (Map.size counts)
            [ termIdAndCount termid freq
            | (termid, freq) <- Map.toAscList counts ]
-- | Lay out the counts of several bags as one dense, row-major table.
--
-- Returns a pair:
--
-- * the ascending union of all term ids (the columns);
-- * a vector of @length termbags * number of columns@ counts, in which row
--   @n@ holds, for every column term, its count in the @n@-th bag (0 if the
--   bag does not contain it).
denseTable :: [TermBag] -> (Vec.Vector TermId, Vec.Vector TermCount)
denseTable termbags =
    let !columns  = unionsTermId termbags
        !rowWidth = Vec.length columns
        !rowCount = length termbags
        !cells    = Vec.create (do
          table <- MVec.new (rowWidth * rowCount)
          -- Each row is written in full by writeMergedTermCounts, because the
          -- columns contain every term of every bag. The unset contents of
          -- the fresh buffer are therefore never observed.
          mapM_ (\(row, TermBag _ packed) ->
                   writeMergedTermCounts columns packed table (row * rowWidth))
                (zip [0 ..] termbags)
          return table)
    in (columns, cells)
-- | Write one row of the dense table into @out@, starting at index @i0@.
--
-- For each term of @xs0@ (the table's columns, ascending), writes the count
-- that the packed entries @ys0@ record for it, or 0 if there is none.
-- Assumes @ys0@ is ascending and that every term in it also occurs in @xs0@.
-- If that does not hold, counts are silently dropped.
writeMergedTermCounts :: Vec.Vector TermId -> Vec.Vector TermIdAndCount ->
                         MVec.MVector s TermCount -> Int -> ST s ()
writeMergedTermCounts xs0 ys0 !out i0 = loop 0 0
  where
    nx = Vec.length xs0
    ny = Vec.length ys0

    loop !ix !iy
      -- Bag exhausted: every remaining column gets a zero count.
      | iy >= ny  = MVec.set (MVec.slice (i0 + ix) (nx - ix) out) 0
      | ix >= nx  = return ()
      | otherwise = do
          let entry = ys0 Vec.! iy
              hit   = xs0 Vec.! ix == getTermId entry
          MVec.write out (i0 + ix) (if hit then getTermCount entry else 0)
          loop (ix + 1) (if hit then iy + 1 else iy)
-- | Ascending union of the term ids of all the given bags.
--
-- Bags are merged in order of increasing number of entries. Presumably this
-- combines the small inputs first, so the large ones pass through as few
-- merges as possible.
unionsTermId :: [TermBag] -> Vec.Vector TermId
unionsTermId tbs = mergeAll (sortBy (compare `on` entryCount) tbs)
  where
    entryCount (TermBag _ packed) = Vec.length packed

    mergeAll []              = Vec.empty
    mergeAll [TermBag _ one] = Vec.map getTermId one
    mergeAll (b0 : b1 : more) = foldl' union3 (union2 b0 b1) more
-- | Ascending union of the term ids of two bags.
union2 :: TermBag -> TermBag -> Vec.Vector TermId
union2 (TermBag _ xs) (TermBag _ ys) =
    Vec.create (do
      -- Disjoint inputs are the worst case, so the buffer is sized for that.
      -- The merge hands back only the prefix it filled.
      buffer <- MVec.new (Vec.length xs + Vec.length ys)
      writeMergedUnion2 xs ys buffer)
-- | Merge the term ids of two ascending vectors of packed entries into @out@,
-- writing each distinct id once, in ascending order. Returns the prefix of
-- @out@ that was filled.
--
-- Requires @out@ to be at least as long as both inputs combined.
writeMergedUnion2 :: Vec.Vector TermIdAndCount -> Vec.Vector TermIdAndCount ->
                     MVec.MVector s TermId -> ST s (MVec.MVector s TermId)
writeMergedUnion2 xs0 ys0 !out = do
    filled <- merge 0 0 0
    return $! MVec.take filled out
  where
    lenX = Vec.length xs0
    lenY = Vec.length ys0

    -- One input is exhausted: copy the rest of the other one after @dst@.
    copyRest src from dst = do
      let rest = Vec.map getTermId (Vec.drop from src)
          n    = Vec.length rest
      Vec.copy (MVec.slice dst n out) rest
      return (dst + n)

    merge !ix !iy !dst
      | ix >= lenX = copyRest ys0 iy dst
      | iy >= lenY = copyRest xs0 ix dst
      | otherwise  =
          let x = getTermId (xs0 Vec.! ix)
              y = getTermId (ys0 Vec.! iy)
          in case compare x y of
               LT -> MVec.write out dst x >> merge (ix + 1) iy       (dst + 1)
               EQ -> MVec.write out dst x >> merge (ix + 1) (iy + 1) (dst + 1)
               GT -> MVec.write out dst y >> merge ix       (iy + 1) (dst + 1)
-- | Ascending union of an already-merged vector of term ids with the term
-- ids of one more bag.
union3 :: Vec.Vector TermId -> TermBag -> Vec.Vector TermId
union3 xs (TermBag _ ys) =
    Vec.create (do
      -- Worst case is disjoint inputs. The merge hands back only the prefix
      -- it filled.
      buffer <- MVec.new (Vec.length xs + Vec.length ys)
      writeMergedUnion3 xs ys buffer)
-- | Merge an ascending vector of term ids with the term ids of an ascending
-- vector of packed entries into @out@, writing each distinct id once, in
-- ascending order. Returns the prefix of @out@ that was filled.
--
-- Requires @out@ to be at least as long as both inputs combined.
writeMergedUnion3 :: Vec.Vector TermId -> Vec.Vector TermIdAndCount ->
                     MVec.MVector s TermId -> ST s (MVec.MVector s TermId)
writeMergedUnion3 xs0 ys0 !out = do
    filled <- merge 0 0 0
    return $! MVec.take filled out
  where
    lenX = Vec.length xs0
    lenY = Vec.length ys0

    -- Copy a remaining tail (already as term ids) after @dst@.
    copyRest rest dst = do
      let n = Vec.length rest
      Vec.copy (MVec.slice dst n out) rest
      return (dst + n)

    merge !ix !iy !dst
      | ix >= lenX = copyRest (Vec.map getTermId (Vec.drop iy ys0)) dst
      | iy >= lenY = copyRest (Vec.drop ix xs0) dst
      | otherwise  =
          let x = xs0 Vec.! ix
              y = getTermId (ys0 Vec.! iy)
          in case compare x y of
               LT -> MVec.write out dst x >> merge (ix + 1) iy       (dst + 1)
               EQ -> MVec.write out dst x >> merge (ix + 1) (iy + 1) (dst + 1)
               GT -> MVec.write out dst y >> merge ix       (iy + 1) (dst + 1)