module Data.NeuralNetwork.Backend.BLASHS.Utils (
DenseVector(..),
DenseMatrix(..),
DenseMatrixArray(..),
newDenseVector,
newDenseVectorCopy,
newDenseVectorConst,
newDenseVectorByGen,
newDenseMatrix,
newDenseMatrixConst,
newDenseMatrixCopy,
newDenseMatrixArray,
Size(..),
denseVectorToVector,
denseVectorConcat,
denseVectorSplit,
denseMatrixArrayAt,
denseMatrixArrayToVector,
denseMatrixArrayFromVector,
v2m, m2v, v2ma, ma2v,
Op(..), AssignTo(..),
sumElements, corr2, conv2, pool, unpool, transpose
) where
import Blas.Generic.Unsafe
import Blas.Primitive.Types
import qualified Data.Vector as BV
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Storable.Mutable as V
import qualified Data.Vector.Storable.Internal as V
import Control.Exception
import Control.Monad
import Data.IORef
import Foreign.Marshal.Array (advancePtr)
import Data.NeuralNetwork.Backend.BLASHS.SIMD
-- | A mutable dense vector: a thin wrapper over a storable 'V.IOVector'.
newtype DenseVector a = DenseVector (V.IOVector a)
-- | A mutable dense matrix given by its row count, column count and element
-- buffer. Elements are addressed row-major: element @(i,j)@ sits at offset
-- @i*c+j@. The buffer may be longer than @r*c@, never shorter.
data DenseMatrix a = DenseMatrix !Int !Int !(V.IOVector a)
-- | @n@ matrices of @r@ rows by @c@ columns stored back to back in one buffer;
-- matrix @i@ occupies offsets @[i*r*c, (i+1)*r*c)@.
data DenseMatrixArray a = DenseMatrixArray !Int !Int !Int !(V.IOVector a)
-- | Allocate a dense vector with room for the given number of elements.
-- The initial contents are whatever 'V.new' leaves in the buffer.
newDenseVector :: V.Storable a => Int -> IO (DenseVector a)
newDenseVector = fmap DenseVector . V.new
-- | Allocate a new dense vector holding a copy of the given vector's elements.
-- The result does not share memory with the input.
newDenseVectorCopy :: V.Storable a => DenseVector a -> IO (DenseVector a)
newDenseVectorCopy (DenseVector src) = DenseVector <$> V.clone src
-- | Allocate a dense vector of @len@ elements, every one set to @x@.
newDenseVectorConst :: V.Storable a => Int -> a -> IO (DenseVector a)
newDenseVectorConst len x = DenseVector <$> V.replicate len x
-- | Allocate a dense vector of @len@ elements by running the generator action
-- @gen@ once per element, in order.
newDenseVectorByGen :: V.Storable a => IO a -> Int -> IO (DenseVector a)
newDenseVectorByGen gen len = fmap DenseVector (V.replicateM len gen)
-- | Allocate a @rows@ × @cols@ dense matrix backed by a fresh buffer of
-- @rows * cols@ elements (initial contents as produced by 'V.new').
newDenseMatrix :: V.Storable a
               => Int   -- ^ number of rows
               -> Int   -- ^ number of columns
               -> IO (DenseMatrix a)
newDenseMatrix rows cols = do
  buf <- V.new (rows * cols)
  return (DenseMatrix rows cols buf)
-- | Allocate a @rows@ × @cols@ dense matrix with every element set to @x@.
newDenseMatrixConst :: V.Storable a => Int -> Int -> a -> IO (DenseMatrix a)
newDenseMatrixConst rows cols x = do
  buf <- V.replicate (rows * cols) x
  return (DenseMatrix rows cols buf)
-- | Allocate a new matrix with the same shape as the input and a copy of its
-- buffer. The result does not share memory with the input.
newDenseMatrixCopy :: V.Storable a => DenseMatrix a -> IO (DenseMatrix a)
newDenseMatrixCopy (DenseMatrix rows cols src) =
  DenseMatrix rows cols <$> V.clone src
-- | Allocate @cnt@ matrices of @rows@ × @cols@ sharing one fresh buffer of
-- @cnt * rows * cols@ elements (initial contents as produced by 'V.new').
newDenseMatrixArray :: V.Storable a
                    => Int   -- ^ number of matrices
                    -> Int   -- ^ rows per matrix
                    -> Int   -- ^ columns per matrix
                    -> IO (DenseMatrixArray a)
newDenseMatrixArray cnt rows cols =
  fmap (DenseMatrixArray cnt rows cols) (V.new (cnt * rows * cols))
-- | View the @i@-th matrix of an array. The result is a slice of the array's
-- buffer (no copy), so writes through it are visible in the array.
-- Asserts @0 <= i < n@.
denseMatrixArrayAt :: V.Storable a => DenseMatrixArray a -> Int -> DenseMatrix a
denseMatrixArrayAt (DenseMatrixArray n r c v) i =
  assert (i >= 0 && i < n) (DenseMatrix r c (V.unsafeSlice offset seg v))
  where
    seg = r * c
    offset = i * seg
-- | Split a matrix array into a boxed vector of its @n@ matrices.
-- Every element is a slice of the array's buffer; nothing is copied.
denseMatrixArrayToVector :: V.Storable a => DenseMatrixArray a -> BV.Vector (DenseMatrix a)
denseMatrixArrayToVector (DenseMatrixArray n r c v) =
  BV.generate n (\i -> DenseMatrix r c (V.unsafeSlice (i * seg) seg v))
  where
    seg = r * c
-- | Pack a non-empty boxed vector of equally-shaped matrices into a single
-- 'DenseMatrixArray'.
--
-- The shape is taken from the first matrix; all matrices are asserted to share
-- it. The flat buffer comes from 'denseVectorConcat', which may reuse the input
-- buffers when they already lie back to back in memory, and copies otherwise.
denseMatrixArrayFromVector :: V.Storable a => BV.Vector (DenseMatrix a) -> IO (DenseMatrixArray a)
denseMatrixArrayFromVector vm
  | BV.null vm = error "denseMatrixArrayFromVector: empty input"
  | otherwise = do
      let (r, c) = size (BV.head vm)
          sameShape mat = size mat == (r, c)
      DenseVector raw <- assert (BV.all sameShape vm) $
                         denseVectorConcat (BV.map m2v vm)
      return $ DenseMatrixArray (BV.length vm) r c raw
-- | Reinterpret a vector's buffer as an @r@ × @c@ matrix (no copy, no check).
v2m :: Int -> Int -> DenseVector a -> DenseMatrix a
v2m r c (DenseVector v) = DenseMatrix r c v

-- | Reinterpret a matrix's buffer as a flat vector (no copy).
m2v :: DenseMatrix a -> DenseVector a
m2v (DenseMatrix _ _ v) = DenseVector v

-- | Reinterpret a vector as @n@ matrices of @r@ × @c@ (no copy); asserts that
-- the vector holds exactly @n*r*c@ elements.
v2ma :: V.Storable a => Int -> Int -> Int -> DenseVector a -> DenseMatrixArray a
v2ma n r c (DenseVector v) = assert (V.length v == n * r * c) $ DenseMatrixArray n r c v

-- | Reinterpret a matrix array's buffer as a flat vector (no copy).
ma2v :: DenseMatrixArray a -> DenseVector a
ma2v (DenseMatrixArray _ _ _ v) = DenseVector v
-- | Snapshot the current contents of a dense vector into an immutable boxed vector.
--
-- The buffer is copied with 'SV.freeze'. 'SV.unsafeFreeze' requires that the
-- mutable vector is never written again, which cannot hold here because the
-- caller keeps the 'DenseVector'. The boxed conversion is also lazy, so an
-- unfrozen alias could let later writes leak into the returned vector.
denseVectorToVector :: V.Storable a => DenseVector a -> IO (BV.Vector a)
denseVectorToVector (DenseVector vs) = BV.convert <$> SV.freeze vs
-- | Concatenate dense vectors into one.
--
-- If every vector ends exactly where the next one starts in memory, the result
-- is a single view spanning all of them and nothing is copied (so it aliases the
-- inputs). Otherwise a fresh vector is allocated and the inputs are copied into
-- it in order. An empty input yields an empty vector.
denseVectorConcat :: V.Storable a => BV.Vector (DenseVector a) -> IO (DenseVector a)
denseVectorConcat vs
  | BV.null vs = newDenseVector 0
  | contiguous = return $ DenseVector $ V.unsafeFromForeignPtr0 fp0 total
  | otherwise = do
      nvec@(DenseVector dst) <- newDenseVector total
      copyParts dst (BV.toList parts)
      return nvec
  where
    parts = BV.map (\(DenseVector p) -> p) vs
    total = BV.sum (BV.map V.length parts)
    -- Only forced on the non-empty, contiguous branch.
    V.MVector _ fp0 = BV.head parts
    -- True when the first buffer ends exactly where the second one begins.
    adjacent (V.MVector len1 fp1) (V.MVector _ fp2) =
      (V.getPtr fp1 `advancePtr` len1) == V.getPtr fp2
    contiguous = BV.and (BV.zipWith adjacent parts (BV.tail parts))
    -- Copy the sources one after another into the destination, which must be
    -- exactly as long as all of them together.
    copyParts out [] = assert (V.length out == 0) $ return ()
    copyParts out (src : rest) = do
      let (here, there) = V.splitAt (V.length src) out
      V.unsafeCopy here src
      copyParts there rest
-- | Split a vector into @n@ consecutive chunks of @c@ elements each.
--
-- The chunks are slices of the original buffer (no copy). The vector must hold
-- at least @n*c@ elements (exactly @n*c@ is allowed); any surplus at the end is
-- not covered by a chunk.
denseVectorSplit :: V.Storable a => Int -> Int -> DenseVector a -> BV.Vector (DenseVector a)
denseVectorSplit n c (DenseVector v) =
  assert (V.length v >= n * c) $
    BV.generate n (\i -> DenseVector (V.unsafeSlice (i * c) c v))
-- | A vector view that starts at element @(x, y)@ of a row-major matrix and runs
-- to the end of the matrix buffer, spilling into the following rows.
-- Asserts that @(x, y)@ lies inside the matrix.
sliceM :: V.Storable a => DenseMatrix a -> (Int, Int) -> DenseVector a
sliceM (DenseMatrix rows cols buf) (x, y) =
  let inBounds = x >= 0 && x < rows && y >= 0 && y < cols
      offset = x * cols + y
  in assert inBounds (DenseVector (V.unsafeDrop offset buf))
-- | Drop the first @n@ elements of a vector view (no copy, no bounds check).
dropV :: V.Storable a => Int -> DenseVector a -> DenseVector a
dropV n (DenseVector v) = DenseVector (V.unsafeDrop n v)
-- | Copy the first @len@ elements of the second vector into the first.
-- Both vectors are asserted to hold at least @len@ elements.
copyV :: V.Storable a => DenseVector a -> DenseVector a -> Int -> IO ()
copyV (DenseVector v1) (DenseVector v2) len =
  assert (V.length v1 >= len && V.length v2 >= len) $
    V.unsafeCopy (V.unsafeTake len v1) (V.unsafeTake len v2)
-- | Read the element at the given index without a bounds check.
unsafeReadV :: V.Storable a => DenseVector a -> Int -> IO a
unsafeReadV (DenseVector buf) = V.unsafeRead buf
-- | Write the element at the given index without a bounds check.
unsafeWriteV :: V.Storable a => DenseVector a -> Int -> a -> IO ()
unsafeWriteV (DenseVector buf) = V.unsafeWrite buf
-- | Read element @(i, j)@ of a row-major matrix. The position is only checked
-- by 'assert' (lower and upper bounds, matching 'sliceM').
unsafeReadM :: V.Storable a => DenseMatrix a -> (Int, Int) -> IO a
unsafeReadM (DenseMatrix r c v) (i, j) =
  assert (i >= 0 && i < r && j >= 0 && j < c) $ V.unsafeRead v (i * c + j)
-- | Write element @(i, j)@ of a row-major matrix. The position is only checked
-- by 'assert' (lower and upper bounds, matching 'sliceM').
unsafeWriteM :: V.Storable a => DenseMatrix a -> (Int, Int) -> a -> IO ()
unsafeWriteM (DenseMatrix r c v) (i, j) a =
  assert (i >= 0 && i < r && j >= 0 && j < c) $ V.unsafeWrite v (i * c + j) a
-- | Shape of a dense container: a vector reports its length, a matrix its
-- (rows, columns) and a matrix array its (count, rows, columns).
class Size a where
  -- | The shape type of the container.
  type Dim a
  -- | Report the container's shape.
  size :: a -> Dim a

instance V.Storable a => Size (DenseVector a) where
  type Dim (DenseVector a) = Int
  size (DenseVector buf) = V.length buf

instance V.Storable a => Size (DenseMatrix a) where
  type Dim (DenseMatrix a) = (Int, Int)
  -- The buffer may be longer than rows * cols, but never shorter.
  size (DenseMatrix rows cols buf) =
    assert (V.length buf >= rows * cols) (rows, cols)

instance V.Storable a => Size (DenseMatrixArray a) where
  type Dim (DenseMatrixArray a) = (Int, Int, Int)
  size (DenseMatrixArray cnt rows cols buf) =
    assert (V.length buf >= cnt * rows * cols) (cnt, rows, cols)
infix 4 :<#, :#>, :<>, :##
infix 4 :.*, :.+
infix 0 <<=, <<+

-- | A deferred operation whose result is stored into a container of shape @c@
-- with '<<=' (assign) or '<<+' (accumulate). Which combinations are actually
-- implemented is decided by the 'AssignTo' instances; the others fail at run
-- time.
data Op :: (* -> *) -> * -> * where
  -- | Row vector times matrix.
  (:<#) :: DenseVector a -> DenseMatrix a -> Op DenseVector a
  -- | Matrix times column vector.
  (:#>) :: DenseMatrix a -> DenseVector a -> Op DenseVector a
  -- | Matrix product; the exact operand layout is defined by the
  -- 'DenseMatrix' instance of 'AssignTo'.
  (:<>) :: DenseMatrix a -> DenseMatrix a -> Op DenseMatrix a
  -- | Outer product of two vectors.
  (:##) :: DenseVector a -> DenseVector a -> Op DenseMatrix a
  -- | Element-wise product and element-wise sum of two operands.
  (:.*), (:.+) :: c a -> c a -> Op c a
  -- | Multiply the target in place by a scalar.
  Scale :: a -> Op c a
  -- | Map a SIMD function over the target in place.
  Apply :: (SIMDPACK a -> SIMDPACK a) -> Op c a
  -- | Combine two operands element-wise with a SIMD function.
  ZipWith :: (SIMDPACK a -> SIMDPACK a -> SIMDPACK a) -> c a -> c a -> Op c a
  -- | Scale the result of another operation by a scalar.
  Scale' :: a -> Op c a -> Op c a
  -- | Run a matrix operation against a matrix array viewed as a single matrix
  -- with one row per array element.
  UnsafeM2MA :: Op DenseMatrix a -> Op DenseMatrixArray a
-- | Containers that can receive the result of an 'Op'.
class AssignTo c a where
  -- | Assign the result of the operation to the target. The target may also act
  -- as an operand, as for 'Scale' and 'Apply'.
  (<<=) :: c a -> Op c a -> IO ()
  -- | Add the result of the operation to the current contents of the target.
  (<<+) :: c a -> Op c a -> IO ()
-- | BLAS/SIMD-backed assignment into a dense vector. Operations without a
-- matching equation fail with an 'error' at run time.
instance (Numeric a, V.Storable a, SIMDable a) => AssignTo DenseVector a where
  -- v := Y^T x, with Y an r x c row-major matrix (x has r elements, v has c).
  (DenseVector v) <<= (DenseVector x :<# DenseMatrix r c y) =
    assert (V.length x == r && V.length v == c) $
      gemv_helper Trans r c 1.0 y c x 0.0 v
  -- v := X y, with X an r x c row-major matrix.
  (DenseVector v) <<= (DenseMatrix r c x :#> DenseVector y) =
    assert (V.length y == c && V.length v == r) $
      gemv_helper NoTrans r c 1.0 x c y 0.0 v
  (DenseVector v) <<= (DenseVector x :.* DenseVector y) =
    let sz = V.length v
    in assert (sz == V.length x && sz == V.length y) $
         hadamard times v x y
  (DenseVector v) <<= (DenseVector x :.+ DenseVector y) =
    let sz = V.length v
    in assert (sz == V.length x && sz == V.length y) $
         hadamard plus v x y
  (DenseVector v) <<= Scale s =
    V.unsafeWith v (\pv -> scal (V.length v) s pv 1)
  (DenseVector v) <<= Apply f = foreach f v v
  -- Same length contract as (:.*) and (:.+), which also go through 'hadamard'.
  (DenseVector v) <<= ZipWith f (DenseVector x) (DenseVector y) =
    let sz = V.length v
    in assert (sz == V.length x && sz == V.length y) $
         hadamard f v x y
  (DenseVector v) <<= Scale' a (DenseMatrix r c x :#> DenseVector y) =
    assert (V.length y == c && V.length v == r) $
      gemv_helper NoTrans r c a x c y 0.0 v
  _ <<= _ = error "Unsupported Op [Vector <<=]."

  -- The accumulating forms pass beta = 1 to BLAS so the old contents are kept.
  (DenseVector v) <<+ (DenseVector x :<# DenseMatrix r c y) =
    assert (V.length x == r && V.length v == c) $
      gemv_helper Trans r c 1.0 y c x 1.0 v
  (DenseVector v) <<+ (DenseMatrix r c x :#> DenseVector y) =
    assert (V.length y == c && V.length v == r) $
      gemv_helper NoTrans r c 1.0 x c y 1.0 v
  (DenseVector v) <<+ Scale' a (DenseMatrix r c x :#> DenseVector y) =
    assert (V.length y == c && V.length v == r) $
      gemv_helper NoTrans r c a x c y 1.0 v
  _ <<+ _ = error "Unsupported Op [Vector <<+]."
-- | BLAS/SIMD-backed assignment into a dense matrix. Operations without a
-- matching equation fail with an 'error' at run time.
instance (Numeric a, V.Storable a, SIMDable a) => AssignTo DenseMatrix a where
  (DenseMatrix dr dc dst) <<= (DenseMatrix xr xc xs :<> DenseMatrix yr yc ys) =
    assert (xc == yc && dc == xr && dr == yr) $
      gemm_helper Trans NoTrans xr yr xc 1.0 xs xc ys xc 0.0 dst xr
  (DenseMatrix dr dc dst) <<= (DenseMatrix xr xc xs :.* DenseMatrix yr yc ys) =
    assert (dr == xr && dr == yr && dc == xc && dc == yc) $
      hadamard times dst xs ys
  (DenseMatrix dr dc dst) <<= (DenseMatrix xr xc xs :.+ DenseMatrix yr yc ys) =
    assert (dr == xr && dr == yr && dc == xc && dc == yc) $
      hadamard plus dst xs ys
  (DenseMatrix dr dc dst) <<= Scale s =
    let len = V.length dst
    in assert (len == dr * dc) $
         V.unsafeWith dst (\pd -> scal len s pd 1)
  -- Element-wise mapping ignores the shape: reuse the vector instance.
  (DenseMatrix _ _ dst) <<= Apply f = DenseVector dst <<= Apply f
  (DenseMatrix dr dc dst) <<= Scale' alpha (DenseMatrix xr xc xs :<> DenseMatrix yr yc ys) =
    assert (xc == yc && dc == xr && dr == yr) $
      gemm_helper Trans NoTrans xr yr xc alpha xs xc ys xc 0.0 dst xr
  _ <<= _ = error "Unsupported Op [Matrix <<=]."

  -- The accumulating forms pass beta = 1 (gemm) or use the rank-1 update
  -- (geru), so the old contents are kept.
  (DenseMatrix dr dc dst) <<+ (DenseMatrix xr xc xs :<> DenseMatrix yr yc ys) =
    assert (xc == yc && dc == xr && dr == yr) $
      gemm_helper Trans NoTrans xr yr xc 1.0 xs xc ys xc 1.0 dst xr
  (DenseMatrix dr dc dst) <<+ (DenseVector x :## DenseVector y) =
    let m = V.length x
        n = V.length y
    in assert (m == dr && n == dc) $
         V.unsafeWith dst $ \pd ->
         V.unsafeWith x $ \px ->
         V.unsafeWith y $ \py ->
           geru RowMajor m n 1.0 px 1 py 1 pd n
  (DenseMatrix dr dc dst) <<+ Scale' alpha (DenseMatrix xr xc xs :<> DenseMatrix yr yc ys) =
    assert (xc == yc && dc == xr && dr == yr) $
      gemm_helper Trans NoTrans xr yr xc alpha xs xc ys xc 1.0 dst xr
  _ <<+ _ = error "Unsupported Op [Matrix <<+]."
-- | Assignment into a matrix array. Only 'UnsafeM2MA' operations, optionally
-- wrapped in 'Scale'', are supported. The @n@ × @r@ × @c@ array is viewed as an
-- @n@ × @(r*c)@ matrix over the same buffer, and the inner operation is handed
-- to the 'DenseMatrix' instance.
instance (Numeric a, V.Storable a, SIMDable a) => AssignTo DenseMatrixArray a where
  target <<= UnsafeM2MA op = asMatrix target <<= op
    where
      asMatrix (DenseMatrixArray n r c buf) = DenseMatrix n (r * c) buf
  target <<= Scale' s (UnsafeM2MA op) = target <<= UnsafeM2MA (Scale' s op)
  _ <<= _ = error "Unsupported Op [MatrixArray <<=]."

  target <<+ UnsafeM2MA op = asMatrix target <<+ op
    where
      asMatrix (DenseMatrixArray n r c buf) = DenseMatrix n (r * c) buf
  target <<+ Scale' s (UnsafeM2MA op) = target <<+ UnsafeM2MA (Scale' s op)
  _ <<+ _ = error "Unsupported Op [MatrixArray <<+]."
-- | Sum the first @r*c@ elements of a matrix buffer, reading them in order.
-- A non-positive element count yields 0.
sumElements :: (V.Storable a, Num a) => DenseMatrix a -> IO a
sumElements (DenseMatrix r c v) = go 0 0
  where
    total = r * c
    -- Strict index and accumulator, so no thunk chain builds up.
    go !i !acc
      | i >= total = return acc
      | otherwise = do
          x <- V.unsafeRead v i
          go (i + 1) (x + acc)
-- | Cross-correlate the input matrix, zero-padded by @p@ on every side, with
-- every kernel in @ks@, and hand the resulting operation to @fun@.
--
-- The padded input is unrolled into a @(u*v)@ × @(kr*kc)@ work matrix, one row
-- per output position, where @u = mr - kr + 2p + 1@ and @v = mc - kc + 2p + 1@.
-- The kernels are packed into an @n@ × @(kr*kc)@ matrix. @fun@ receives their
-- product wrapped in 'UnsafeM2MA', so it is meant to be assigned into a matrix
-- array (presumably @n@ × @u@ × @v@; confirm against the 'DenseMatrix'
-- instance's @:<>@ assertion).
--
-- @ks@ must be non-empty, and all kernels are assumed to share the shape of
-- the first.
corr2 :: (V.Storable a, Numeric a)
      => Int                              -- ^ padding @p@
      -> BV.Vector (DenseMatrix a)        -- ^ kernels
      -> DenseMatrix a                    -- ^ input matrix
      -> (Op DenseMatrixArray a -> IO b)  -- ^ consumer of the operation
      -> IO b
corr2 p ks m fun = do
  let (kr, kc) = size (BV.head ks)
      (mr, mc) = size m
      outR = mr - kr + 2 * p + 1
      outC = mc - kc + 2 * p + 1
  zpd <- zero m mr mc p
  wrk <- newDenseMatrix (outR * outC) (kr * kc)
  fill wrk zpd outR outC kr kc
  DenseMatrixArray n r c kernels <- denseMatrixArrayFromVector ks
  fun $ UnsafeM2MA $ wrk :<> DenseMatrix n (r * c) kernels
-- | Convolve the input matrix, zero-padded by @p@ on every side, with every
-- kernel in @ks@, and hand the resulting operation to @fun@.
--
-- Identical to 'corr2', except that each kernel is rotated by 180 degrees
-- before use. The rotated kernels go into a freshly allocated buffer, so @ks@
-- is left untouched. @ks@ must be non-empty, and all kernels are assumed to
-- share the shape of the first.
conv2 :: (V.Storable a, Numeric a)
      => Int                              -- ^ padding @p@
      -> BV.Vector (DenseMatrix a)        -- ^ kernels
      -> DenseMatrix a                    -- ^ input matrix
      -> (Op DenseMatrixArray a -> IO b)  -- ^ consumer of the operation
      -> IO b
conv2 p ks m fun = do
  let (kr, kc) = size (BV.head ks)
      (mr, mc) = size m
      outR = mr - kr + 2 * p + 1
      outC = mc - kc + 2 * p + 1
      nk = BV.length ks
  zpd <- zero m mr mc p
  wrk <- newDenseMatrix (outR * outC) (kr * kc)
  fill wrk zpd outR outC kr kc
  knl@(DenseMatrixArray _ _ _ kernels) <- newDenseMatrixArray nk kr kc
  -- Store the kernels in reverse order, then reverse the whole flat buffer.
  -- The two reversals cancel on the kernel order, and each kernel ends up
  -- with its elements reversed, i.e. rotated by 180 degrees.
  forM_ [0 .. nk - 1] $ \i -> do
    let DenseMatrix _ _ d = denseMatrixArrayAt knl i
        DenseMatrix _ _ s = ks BV.! (nk - 1 - i)
    V.unsafeCopy d s
  reverseV kernels
  fun $ UnsafeM2MA $ wrk :<> DenseMatrix nk (kr * kc) kernels
  where
    -- In-place reversal. Only indices below e `div` 2 are swapped with their
    -- mirror: iterating up to e `div` 2 inclusive would swap the middle pair of
    -- an even-length buffer back, and would touch index -1 when it is empty.
    reverseV vec =
      let e = V.length vec
      in forM_ [0 .. e `div` 2 - 1] (\i -> V.unsafeSwap vec i (e - 1 - i))
-- | Copy the @mr@ × @mc@ matrix @m@ into the centre of a new
-- @(mr+2p)@ × @(mc+2p)@ matrix, leaving a border of width @p@ around it.
--
-- NOTE(review): the border keeps whatever 'newDenseMatrix' (i.e. 'V.new')
-- leaves in the buffer, so zero padding relies on that allocation being
-- zero-filled. Confirm this for the @vector@ version in use.
zero :: V.Storable a => DenseMatrix a -> Int -> Int -> Int -> IO (DenseMatrix a)
zero m mr mc p = do
  zpd <- newDenseMatrix (mr + 2 * p) (mc + 2 * p)
  -- Row i of the input lands at row p+i, starting at column p.
  forM_ [0 .. mr - 1] $ \i ->
    copyV (sliceM zpd (p + i, p)) (sliceM m (i, 0)) mc
  return zpd
-- | Unroll matrix @m@ into the work matrix (im2col layout).
--
-- For every output position @(i, j)@ with @i < u@ and @j < v@, the @kr@ × @kc@
-- window of @m@ whose top-left corner is @(i, j)@ is copied row by row into
-- row @i*v + j@ of the work matrix. The work buffer is expected to hold at
-- least @u*v@ rows of @kr*kc@ elements.
fill :: V.Storable a => DenseMatrix a -> DenseMatrix a -> Int -> Int -> Int -> Int -> IO ()
fill (DenseMatrix _ _ vwrk) m u v kr kc =
  forM_ [0 .. u - 1] $ \i ->
    forM_ [0 .. v - 1] $ \j ->
      forM_ [0 .. kr - 1] $ \k -> do
        -- Segments of kc elements are laid out back to back in loop order.
        let off = ((i * v + j) * kr + k) * kc
        copyV (dropV off (DenseVector vwrk)) (sliceM m (i + k, j)) kc
-- | Max-pool a matrix over non-overlapping @stride@ × @stride@ windows.
--
-- Returns, for each window in row-major order, the position of its maximum
-- inside the window encoded as @di * stride + dj@, together with the matrix of
-- maxima. Ties keep the first position in row-major order. Trailing rows or
-- columns that do not fill a whole window are ignored. With a stride of 1 the
-- input matrix itself is returned and every position is 0.
pool :: Int -> DenseMatrix Float -> IO (DenseVector Int, DenseMatrix Float)
pool 1 mat = do
  let (r, c) = size mat
  -- Every 1x1 window's maximum is at in-window offset 0.
  vi <- newDenseVectorConst (r * c) 0
  return (vi, mat)
pool stride mat = do
  mxi <- newDenseVector (r' * c')
  mxv <- newDenseMatrix r' c'
  forM_ [0 .. r' - 1] $ \i ->
    forM_ [0 .. c' - 1] $ \j -> do
      (n, val) <- windowMax (i * stride) (j * stride)
      unsafeWriteV mxi (i * c' + j) n
      unsafeWriteM mxv (i, j) val
  return (mxi, mxv)
  where
    (r, c) = size mat
    r' = r `div` stride
    c' = c `div` stride
    -- Maximum of the window whose top-left corner is (x, y). It is seeded with
    -- the window's first element rather than a sentinel, so any value range
    -- is handled correctly.
    windowMax x y = do
      seed <- unsafeReadM mat (x, y)
      let step (bestPos, bestVal) (di, dj) = do
            cand <- unsafeReadM mat (x + di, y + dj)
            return $! if cand > bestVal
                        then (di * stride + dj, cand)
                        else (bestPos, bestVal)
      foldM step (0, seed) [(di, dj) | di <- [0 .. stride - 1], dj <- [0 .. stride - 1]]
-- | Inverse of 'pool'. Each pooled value is scattered back into its
-- @stride@ × @stride@ window at the in-window position recorded in @idx@,
-- which is encoded as @di * stride + dj@. Every other position of the result
-- is zero.
unpool :: Int -> DenseVector Int -> DenseMatrix Float -> IO (DenseMatrix Float)
unpool stride idx mat = do
  -- Explicitly zeroed: only one cell per window is written below.
  mat' <- newDenseMatrixConst r' c' 0
  forM_ [0 .. r - 1] $ \i ->
    forM_ [0 .. c - 1] $ \j -> do
      pos <- unsafeReadV idx (i * c + j)
      val <- unsafeReadM mat (i, j)
      -- Decode with the same stride 'pool' used to encode the position.
      let (oi, oj) = pos `divMod` stride
      unsafeWriteM mat' (i * stride + oi, j * stride + oj) val
  return mat'
  where
    (r, c) = size mat
    (r', c') = (r * stride, c * stride)
-- | Regroup a vector of matrix arrays by matrix index. Element @i@ of the result
-- holds the @i@-th matrix of every input array, in input order; the matrices
-- are slices of the arrays' buffers. The count @n@ is taken from the first
-- array, so the input must be non-empty.
transpose :: V.Storable a => BV.Vector (DenseMatrixArray a) -> IO (BV.Vector (BV.Vector (DenseMatrix a)))
transpose vma =
  let DenseMatrixArray n _ _ _ = BV.head vma
      column i = BV.map (`denseMatrixArrayAt` i) vma
  in return $! BV.generate n column
-- | Row-major BLAS @gemv@ over storable buffers with unit increments:
-- @v := alpha * op(x) * y + beta * v@, where @x@ is a @row@ × @col@ matrix with
-- leading dimension @lda@ and @op@ is selected by @trans@.
gemv_helper :: Numeric a
            => Transpose
            -> Int -> Int
            -> a
            -> V.IOVector a
            -> Int
            -> V.IOVector a
            -> a
            -> V.IOVector a -> IO ()
gemv_helper trans row col alpha x lda y beta v =
  V.unsafeWith x $ \px ->
  V.unsafeWith y $ \py ->
  V.unsafeWith v $ \pv ->
    gemv RowMajor trans row col alpha px lda py 1 beta pv 1
-- | Column-major BLAS @gemm@ over storable buffers:
-- @v := alpha * op(x) * op(y) + beta * v@. Here @op(x)@ is @rowA@ × @colA@,
-- @op(y)@ is @colA@ × @colB@, and @xlda@, @ylda@, @vlda@ are the leading
-- dimensions of @x@, @y@ and @v@.
gemm_helper :: Numeric a
            => Transpose
            -> Transpose
            -> Int -> Int -> Int
            -> a
            -> V.IOVector a
            -> Int
            -> V.IOVector a
            -> Int
            -> a
            -> V.IOVector a
            -> Int
            -> IO ()
gemm_helper transA transB rowA colB colA alpha x xlda y ylda beta v vlda =
  V.unsafeWith x $ \px ->
  V.unsafeWith y $ \py ->
  V.unsafeWith v $ \pv ->
    gemm ColMajor transA transB rowA colB colA alpha px xlda py ylda beta pv vlda