{-# LANGUAGE
BangPatterns
, FlexibleContexts
, FlexibleInstances
, FunctionalDependencies
, MultiParamTypeClasses
, TypeFamilies
#-}
{-# LANGUAGE Trustworthy #-}
module SAT.Mios.Vec
(
VecFamily (..)
, Vec (..)
, SingleStorage (..)
, Bool'
, Double'
, Int'
, StackFamily (..)
, Stack
, newStackFromList
, realLength
, sortStack
)
where
import qualified Data.Vector.Unboxed.Mutable as UV
import qualified Data.Primitive.ByteArray as BA
import Control.Monad.Primitive
-- | Mutable, integer-indexed containers. The functional dependency @v -> a@
-- fixes the element type @a@ from the container type @v@.
--
-- Only 'getNth' and 'setNth' must be implemented. 'swapBetween' and
-- 'modifyNth' default to generic implementations built from those two;
-- every other method raises a runtime error unless the instance overrides it.
class VecFamily v a | v -> a where
  -- | Read the element at an index.
  getNth :: v -> Int -> IO a
  -- | Overwrite the element at an index.
  setNth :: v -> Int -> a -> IO ()
  -- | Instance-specific reset operation.
  reset :: v -> IO ()
  -- | Exchange the elements at two indices.
  swapBetween :: v -> Int -> Int -> IO ()
  -- | Apply a function to the element at an index, in place.
  modifyNth :: v -> (a -> a) -> Int -> IO ()
  -- | @newVec n x@ allocates a container sized for @n@ elements, initialised
  -- from @x@; the exact layout (e.g. an extra reserved slot) is
  -- instance-specific.
  newVec :: Int -> a -> IO v
  -- | Overwrite every element with a value.
  setAll :: v -> a -> IO ()
  -- | Return a container with room for the given number of extra elements
  -- that holds the existing contents.
  growBy :: v -> Int -> IO v
  -- | Read the contents out as a list.
  asList :: v -> IO [a]
  {-# MINIMAL getNth, setNth #-}
  reset = errorWithoutStackTrace "no default method: reset"
  -- Generic fallbacks expressed through the minimal methods, so an instance
  -- that only defines getNth/setNth still gets working swap/modify.
  swapBetween v i j = do
    x <- getNth v i
    y <- getNth v j
    setNth v i y
    setNth v j x
  modifyNth v f i = setNth v i . f =<< getNth v i
  newVec = errorWithoutStackTrace "no default method: newVec"
  setAll = errorWithoutStackTrace "no default method: setAll"
  asList = errorWithoutStackTrace "no default method: asList"
  growBy = errorWithoutStackTrace "no default method: growBy"
-- | Packed mutable vectors indexed by element type. Concrete representations
-- are supplied by @data instance@ declarations (for 'Int' and 'Double').
data family Vec a

-- | Shorthand for an unboxed mutable vector living in 'IO'.
type UVector = UV.IOVector
-- | Unboxed 'Int' vectors: a thin layer over the unchecked operations of
-- "Data.Vector.Unboxed.Mutable". Indices start at 0 and are not
-- bounds-checked.
instance VecFamily (UVector Int) Int where
  {-# SPECIALIZE INLINE getNth :: UVector Int -> Int -> IO Int #-}
  getNth vec i = UV.unsafeRead vec i
  {-# SPECIALIZE INLINE setNth :: UVector Int -> Int -> Int -> IO () #-}
  setNth vec i x = UV.unsafeWrite vec i x
  {-# SPECIALIZE INLINE modifyNth :: UVector Int -> (Int -> Int) -> Int -> IO () #-}
  modifyNth vec f i = UV.unsafeModify vec f i
  {-# SPECIALIZE INLINE swapBetween:: UVector Int -> Int -> Int -> IO () #-}
  swapBetween vec i j = UV.unsafeSwap vec i j
  {-# SPECIALIZE INLINE newVec :: Int -> Int -> IO (UVector Int) #-}
  -- A zero fill is skipped: the freshly created vector is used as-is.
  newVec n x
    | x == 0 = UV.new n
    | otherwise = UV.new n >>= \vec -> UV.set vec x >> return vec
  {-# SPECIALIZE INLINE setAll :: UVector Int -> Int -> IO () #-}
  setAll vec x = UV.set vec x
  {-# SPECIALIZE INLINE growBy :: UVector Int -> Int -> IO (UVector Int) #-}
  growBy vec k = UV.unsafeGrow vec k
  asList vec = traverse (UV.unsafeRead vec) (enumFromTo 0 (UV.length vec - 1))
-- | 'Int' vectors are backed by a raw mutable byte array.
data instance Vec Int = ByteArrayInt (BA.MutableByteArray RealWorld)
-- | Readable name for the byte-array-backed 'Int' vector.
type ByteArrayInt = Vec Int

-- | 'Double' vectors are backed by a raw mutable byte array.
data instance Vec Double = ByteArrayDouble (BA.MutableByteArray RealWorld)
-- | Readable name for the byte-array-backed 'Double' vector.
type ByteArrayDouble = Vec Double
-- | 'Int' vectors over a raw byte array of 8-byte slots. 'newVec' reserves
-- slot 0 (set to 0) in front of the @n@ payload slots, so payload indices
-- run from 1 to @n@. No index is bounds-checked.
instance VecFamily ByteArrayInt Int where
  {-# SPECIALIZE INLINE getNth :: ByteArrayInt -> Int -> IO Int #-}
  getNth (ByteArrayInt arr) k = BA.readByteArray arr k
  {-# SPECIALIZE INLINE setNth :: ByteArrayInt -> Int -> Int -> IO () #-}
  setNth (ByteArrayInt arr) k val = BA.writeByteArray arr k val
  {-# SPECIALIZE INLINE modifyNth :: ByteArrayInt -> (Int -> Int) -> Int -> IO () #-}
  modifyNth (ByteArrayInt arr) f k = do
    old <- BA.readByteArray arr k
    BA.writeByteArray arr k (f old)
  {-# SPECIALIZE INLINE swapBetween:: ByteArrayInt -> Int -> Int -> IO () #-}
  swapBetween (ByteArrayInt arr) i j = do
    a <- BA.readByteArray arr i :: IO Int
    b <- BA.readByteArray arr j :: IO Int
    BA.writeByteArray arr i b
    BA.writeByteArray arr j a
  {-# SPECIALIZE INLINE reset :: ByteArrayInt -> IO () #-}
  -- Only slot 0 is cleared; payload slots keep their contents.
  reset (ByteArrayInt arr) = BA.writeByteArray arr 0 (0 :: Int)
  {-# SPECIALIZE INLINE newVec :: Int -> Int -> IO ByteArrayInt #-}
  newVec n k = do
    arr <- BA.newByteArray (8 * (n + 1))
    BA.writeByteArray arr 0 (0 :: Int)
    BA.setByteArray arr 1 n k
    return (ByteArrayInt arr)
  -- Copies the existing bytes into a larger array; the added slots are not
  -- initialised here.
  growBy (ByteArrayInt arr) n = do
    let oldBytes = BA.sizeofMutableByteArray arr
    bigger <- BA.newByteArray (oldBytes + 8 * n)
    BA.copyMutableByteArray bigger 0 arr 0 oldBytes
    return (ByteArrayInt bigger)
  -- Every slot, including slot 0.
  asList (ByteArrayInt arr) =
    traverse (BA.readByteArray arr) (enumFromTo 0 (BA.sizeofMutableByteArray arr `div` 8 - 1))
-- | 'Double' vectors over a raw byte array of 8-byte slots, mirroring the
-- 'ByteArrayInt' instance: 'newVec' reserves slot 0 (set to 0) in front of
-- the @n@ payload slots, and no index is bounds-checked.
instance VecFamily ByteArrayDouble Double where
  {-# SPECIALIZE INLINE getNth :: ByteArrayDouble -> Int -> IO Double #-}
  getNth (ByteArrayDouble v) i = BA.readByteArray v i
  {-# SPECIALIZE INLINE setNth :: ByteArrayDouble -> Int -> Double -> IO () #-}
  setNth (ByteArrayDouble v) i x = BA.writeByteArray v i x
  {-# SPECIALIZE INLINE modifyNth :: ByteArrayDouble -> (Double -> Double) -> Int -> IO () #-}
  modifyNth (ByteArrayDouble v) f i = BA.writeByteArray v i . f =<< BA.readByteArray v i
  -- Elements are moved as 'Double' (previously as 'Int', which was only
  -- correct where 'Int' happens to be 8 bytes wide).
  {-# SPECIALIZE INLINE swapBetween:: ByteArrayDouble -> Int -> Int -> IO () #-}
  swapBetween (ByteArrayDouble v) i j = do
    x <- BA.readByteArray v i :: IO Double
    y <- BA.readByteArray v j :: IO Double
    BA.writeByteArray v i y
    BA.writeByteArray v j x
  {-# SPECIALIZE INLINE reset :: ByteArrayDouble -> IO () #-}
  -- Only slot 0 is cleared; payload slots keep their contents.
  reset (ByteArrayDouble v) = BA.writeByteArray v 0 (0 :: Double)
  {-# SPECIALIZE INLINE newVec :: Int -> Double -> IO ByteArrayDouble #-}
  newVec n k = do
    v <- BA.newByteArray (8 * (n + 1))
    BA.writeByteArray v 0 (0 :: Double)
    BA.setByteArray v 1 n k
    return (ByteArrayDouble v)
  -- Previously missing (fell through to the erroring class default).
  -- Copies the existing bytes into a larger array; the added slots are not
  -- initialised here.
  growBy (ByteArrayDouble v) n = do
    let oldBytes = BA.sizeofMutableByteArray v
    v' <- BA.newByteArray (oldBytes + 8 * n)
    BA.copyMutableByteArray v' 0 v 0 oldBytes
    return (ByteArrayDouble v')
  -- Every slot, including slot 0.
  asList (ByteArrayDouble v) = mapM (BA.readByteArray v) [0 .. div (BA.sizeofMutableByteArray v) 8 - 1]
-- | Total number of 8-byte slots in the underlying array, including slot 0.
{-# INLINE realLength #-}
realLength :: Vec Int -> Int
realLength (ByteArrayInt bytes) = BA.sizeofMutableByteArray bytes `div` 8
-- | A mutable cell holding one value of type @t@; the functional dependency
-- @s -> t@ fixes the value type from the storage type.
--
-- Only 'get'' and 'set'' must be implemented. 'modify'' defaults to a read
-- followed by a write; 'new'' raises a descriptive runtime error unless the
-- instance overrides it (previously both defaults were bare 'undefined').
class SingleStorage s t | s -> t where
  -- | Allocate a new cell holding the given value.
  new' :: t -> IO s
  -- | Read the stored value.
  get' :: s -> IO t
  -- | Overwrite the stored value.
  set' :: s -> t -> IO ()
  -- | Apply a function to the stored value in place.
  modify' :: s -> (t -> t) -> IO ()
  {-# MINIMAL get', set' #-}
  new' = errorWithoutStackTrace "no default method: new'"
  modify' s f = set' s . f =<< get' s
-- | A single mutable 'Bool', kept in a one-element unboxed vector.
type Bool' = UV.IOVector Bool

instance SingleStorage Bool' Bool where
  {-# SPECIALIZE INLINE new' :: Bool -> IO Bool' #-}
  new' b = UV.new 1 >>= \cell -> UV.unsafeWrite cell 0 b >> return cell
  {-# SPECIALIZE INLINE get' :: Bool' -> IO Bool #-}
  get' cell = UV.unsafeRead cell 0
  {-# SPECIALIZE INLINE set' :: Bool' -> Bool -> IO () #-}
  set' cell !b = UV.unsafeWrite cell 0 b
  {-# SPECIALIZE INLINE modify' :: Bool' -> (Bool -> Bool) -> IO () #-}
  modify' cell g = UV.unsafeModify cell g 0
-- | A single mutable 'Int'. 'new'' allocates a one-slot (8-byte) array; the
-- accessors always address slot 0, so on a larger 'ByteArrayInt' they read
-- and write its reserved first slot.
type Int' = ByteArrayInt

instance SingleStorage ByteArrayInt Int where
  {-# SPECIALIZE INLINE new' :: Int -> IO ByteArrayInt #-}
  new' n = do
    cell <- BA.newByteArray 8
    BA.writeByteArray cell 0 n
    return (ByteArrayInt cell)
  {-# SPECIALIZE INLINE get' :: ByteArrayInt -> IO Int #-}
  get' (ByteArrayInt cell) = BA.readByteArray cell 0
  {-# SPECIALIZE INLINE set' :: ByteArrayInt -> Int -> IO () #-}
  set' (ByteArrayInt cell) !n = BA.writeByteArray cell 0 n
  {-# SPECIALIZE INLINE modify' :: ByteArrayInt -> (Int -> Int) -> IO () #-}
  modify' (ByteArrayInt cell) g = do
    n <- BA.readByteArray cell 0
    BA.writeByteArray cell 0 (g n)
-- | A single mutable 'Double'. 'new'' allocates a one-slot (8-byte) array;
-- the accessors always address slot 0.
type Double' = ByteArrayDouble

instance SingleStorage ByteArrayDouble Double where
  {-# SPECIALIZE INLINE new' :: Double -> IO ByteArrayDouble #-}
  new' d = do
    cell <- BA.newByteArray 8
    BA.writeByteArray cell 0 d
    return (ByteArrayDouble cell)
  {-# SPECIALIZE INLINE get' :: ByteArrayDouble -> IO Double #-}
  get' (ByteArrayDouble cell) = BA.readByteArray cell 0
  {-# SPECIALIZE INLINE set' :: ByteArrayDouble -> Double -> IO () #-}
  set' (ByteArrayDouble cell) !d = BA.writeByteArray cell 0 d
  {-# SPECIALIZE INLINE modify' :: ByteArrayDouble -> (Double -> Double) -> IO () #-}
  modify' (ByteArrayDouble cell) g = do
    d <- BA.readByteArray cell 0
    BA.writeByteArray cell 0 (g d)
-- | LIFO stacks of @t@. The superclass gives each stack an 'Int' cell; the
-- instance in this module uses it as the element count (NOTE(review): other
-- instances presumably should too — confirm).
--
-- All methods default to a descriptive runtime error (previously bare
-- 'undefined'), matching the style of 'VecFamily'.
class SingleStorage s Int => StackFamily s t | s -> t where
  -- | Create an empty stack, sized from the given capacity hint.
  newStack :: Int -> IO s
  -- | Push an element on top.
  pushTo :: s -> t -> IO ()
  -- | Drop the top element.
  popFrom :: s -> IO ()
  -- | Read the top element.
  lastOf :: s -> IO t
  -- | Drop the given number of elements from the top.
  shrinkBy :: s -> Int -> IO ()
  newStack = errorWithoutStackTrace "no default method: newStack"
  pushTo = errorWithoutStackTrace "no default method: pushTo"
  popFrom = errorWithoutStackTrace "no default method: popFrom"
  lastOf = errorWithoutStackTrace "no default method: lastOf"
  shrinkBy = errorWithoutStackTrace "no default method: shrinkBy"
-- | The solver's stack: a 'Vec' 'Int' whose slot 0 holds the number of live
-- elements and whose slots @1 .. count@ hold the elements, bottom first.
-- No operation checks capacity or underflow.
type Stack = Vec Int

instance StackFamily ByteArrayInt Int where
  {-# SPECIALIZE INLINE newStack :: Int -> IO ByteArrayInt #-}
  -- Capacity is twice the requested size; the count in slot 0 starts at 0.
  newStack size = do
    stack <- newVec (2 * size) 0
    setNth stack 0 (0 :: Int)
    return stack
  {-# SPECIALIZE INLINE pushTo :: ByteArrayInt -> Int -> IO () #-}
  pushTo (ByteArrayInt arr) x = do
    count <- BA.readByteArray arr 0 :: IO Int
    let top = count + 1
    BA.writeByteArray arr top x
    BA.writeByteArray arr 0 top
  {-# SPECIALIZE INLINE popFrom :: ByteArrayInt -> IO () #-}
  popFrom (ByteArrayInt arr) = do
    count <- BA.readByteArray arr 0 :: IO Int
    BA.writeByteArray arr 0 (count - 1)
  {-# SPECIALIZE INLINE lastOf :: ByteArrayInt -> IO Int #-}
  lastOf (ByteArrayInt arr) = do
    top <- BA.readByteArray arr 0 :: IO Int
    BA.readByteArray arr top
  {-# SPECIALIZE INLINE shrinkBy :: ByteArrayInt -> Int -> IO () #-}
  shrinkBy (ByteArrayInt arr) k = do
    count <- BA.readByteArray arr 0 :: IO Int
    BA.writeByteArray arr 0 (count - k)
-- | Build a 'Stack' holding exactly the given elements, in list order.
--
-- Slot 0 receives the element count and slots @1 .. length l@ the elements.
-- The array is allocated with no spare slots beyond these.
{-# INLINABLE newStackFromList #-}
newStackFromList :: [Int] -> IO Stack
newStackFromList l = do
  let count = length l
  arr <- BA.newByteArray (8 * (count + 1))
  mapM_ (uncurry (BA.writeByteArray arr)) (zip [0 ..] (count : l))
  return (ByteArrayInt arr)
-- | Sort the live elements of a 'Stack' into ascending order, in place.
--
-- The element count is read from slot 0 via 'get'', and slots @1 .. count@
-- are sorted. The algorithm is a recursive quicksort. The middle element of
-- each range is the pivot, and it is parked at the range's left end. A
-- two-pointer scan then partitions the remaining elements around it.
{-# INLINABLE sortStack #-}
sortStack :: Stack -> IO ()
sortStack vec = get' vec >>= \count -> sortRange count 1 count
  where
    -- Sort the inclusive index range [lo, hi]; @count@ is the stack size.
    sortRange :: Int -> Int -> Int -> IO ()
    sortRange count !lo !hi
      | count < lo || hi < 1 || hi <= lo = return ()
      | hi == lo + 1 = do
          a <- getNth vec lo
          b <- getNth vec hi
          if a < b then return () else setNth vec lo b >> setNth vec hi a
      | otherwise = do
          let mid = (lo + hi) `div` 2
          pivot <- getNth vec mid
          -- With the pivot at @lo@, scanDown always stops at @lo@ at the
          -- latest, so it needs no explicit lower bound.
          swapBetween vec mid lo
          let scanUp !i
                | hi < i = return i
                | otherwise = do
                    x <- getNth vec i
                    if x < pivot then scanUp (i + 1) else return i
              scanDown !j = do
                x <- getNth vec j
                if pivot < x then scanDown (j - 1) else return j
              split !i !j = do
                i' <- scanUp i
                j' <- scanDown j
                if i' < j'
                  then swapBetween vec i' j' >> split (i' + 1) (j' - 1)
                  else return j'
          m <- split (lo + 1) hi
          -- Move the pivot into its final position between the two halves.
          swapBetween vec lo m
          sortRange count lo (m - 1)
          sortRange count (m + 1) hi