-- | Example program for the batched CUBLAS bindings in
-- "Data.Array.Accelerate.CUBLAS.Level3.Batched".  It exercises batched
-- matrix multiplication, LU factorisation and solving, and matrix
-- inversion, running every computation on the Accelerate CUDA backend.
module Main where

import Control.Exception (bracket)

import qualified Data.Array.Accelerate.CUBLAS.Level3.Batched as Batched
import qualified Data.Array.Accelerate.Arithmetic.LinearAlgebra as ALinAlg
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import Data.Array.Accelerate.Utility.Lift.Exp (expr)
import Data.Array.Accelerate (DIM1, Z(Z), (:.)((:.)), (!), (?), (==*))
import qualified Data.Array.Accelerate.CUDA as AC
import qualified Data.Array.Accelerate as A

import qualified Foreign.CUDA.Cublas as Cublas

import System.Random (randomRs, mkStdGen)

-- | Two stacks of 100 matrices each.  The first stack holds 3×4 matrices
-- and the second holds 4×2 matrices, so the matrices can be multiplied
-- pairwise.  The element at index @(i, j, k)@ is @i + j + k@ converted to
-- 'Double'.  Here @i@ is the position in the stack and @j@, @k@ are the
-- indices inside the matrix.
factorMatrices :: (ALinAlg.Matrix DIM1 Double, ALinAlg.Matrix DIM1 Double)
factorMatrices =
  let numMats :: Int
      numMats = 100
      f = Exp.modify (expr :. expr :. expr :. expr) $
            \(_z :. i :. j :. k) -> A.fromIntegral (i + j + k)
  in  ( A.generate (A.constant $ Z :. numMats :. 3 :. 4) f
      , A.generate (A.constant $ Z :. numMats :. 4 :. 2) f
      )

-- | Print 'factorMatrices' and then their pairwise products as computed by
-- 'Batched.mul'.  The literal @1@ passed to 'Batched.mul' is presumably the
-- scalar multiplier; NOTE(review): confirm this against its documentation.
mainMul :: Cublas.Handle -> IO ()
mainMul handle = do
  print factorMatrices
  print $ AC.run $ uncurry (Batched.mul handle 1) factorMatrices

-- | A pair of fixed 4×4 matrices, written in row-major order.  'mainLU'
-- multiplies them together to build the second matrix of its test batch.
luMatrices :: (ALinAlg.Matrix Z Double, ALinAlg.Matrix Z Double)
luMatrices =
  ( A.use $ A.fromList (Z :. 4 :. 4)
      [ 2, 0, 0, 0
      , 1, 3, 0, 0
      , 0, 1, 4, 0
      , 0, 0, 1, 5 ]
  , A.use $ A.fromList (Z :. 4 :. 4)
      [ 0, 1, 1, 0
      , 0, 0, 1, 1
      , 0, 0, 0, 1
      , 1, 1, 0, 0 ]
  )

-- | A 4×4 matrix with exactly one non-zero entry in every row and every
-- column.  It is therefore a scaled permutation matrix, and invertible.
permMatrix :: ALinAlg.Matrix Z Double
permMatrix =
  A.use $ A.fromList (Z :. 4 :. 4)
    [ 0, 2, 0, 0
    , 0, 0, 0, 4
    , 0, 0, 3, 0
    , 1, 0, 0, 0 ]

-- | A fixed 4×2 right-hand side, used by 'mainLU'.
rhsMatrix :: ALinAlg.Matrix Z Double
rhsMatrix =
  A.use $ A.fromList (Z :. 4 :. 2)
    [ 2, 5
    , 8, 6
    , 8, 3
    , 7, 6 ]

-- | Stack two matrices into a batch of two.  Index 0 of the new leading
-- dimension selects @x@ and index 1 selects @y@.  The extents of the
-- result are the intersection of the two input shapes, so when the shapes
-- differ only the overlapping region is kept.
append ::
  (A.Elt a) =>
  ALinAlg.Matrix Z a -> ALinAlg.Matrix Z a -> ALinAlg.Matrix DIM1 a
append x y =
  let (_z :. m :. n) =
        Exp.unlift (expr :. expr :. expr) $ A.intersect (A.shape x) (A.shape y)
  in  A.generate (A.lift $ Z :. (2 :: Int) :. m :. n) $
        Exp.modify (expr :. expr :. expr :. expr) $
          \(_z :. k :. i :. j) ->
            let ix = A.index2 i j
            in  k ==* 0 ? (x ! ix, y ! ix)

-- | Build a batch from 'permMatrix' and the product of the two
-- 'luMatrices', then LU-factorise it.  Next, multiply each batch matrix by
-- 'rhsMatrix' and feed the products to 'Batched.luSolve'.
mainLU :: Cublas.Handle -> IO ()
mainLU handle = do
  print luMatrices
  let luProduct = uncurry (Batched.mul handle 1) luMatrices
      mat = append permMatrix luProduct
      lu = Batched.lu handle mat
      -- Make one copy of the right-hand side for each of the two matrices
      -- in the batch.
      rhs = A.replicate (A.lift $ Z :. (2 :: Int) :. A.All :. A.All) rhsMatrix
  -- If luSolve solves @mat * x = b@, both batch entries of the printed
  -- result should reproduce 'rhsMatrix' up to rounding.
  -- NOTE(review): confirm these luSolve semantics.
  print $ AC.run $ Batched.luSolve handle lu $ Batched.mul handle 1 mat rhs

-- | Invert three pseudo-random 4×4 'Float' matrices.  Their entries are
-- drawn from @[-1, 1]@ using the fixed seed 42.  For each matrix, print its
-- product with the computed inverse, together with the second component
-- returned by 'Batched.inv'.
mainInv :: Cublas.Handle -> IO ()
mainInv handle = do
  let dim = 4
      mat = A.fromList (Z :. 3 :. dim :. dim :: A.DIM3) $
              randomRs (-1, 1 :: Float) $ mkStdGen 42
      -- Each product should be close to the identity when inversion
      -- succeeds.  NOTE(review): @info@ is presumably a per-matrix status;
      -- confirm this.
      test a =
        let (inv, info) = Batched.inv handle a
        in  A.lift (Batched.mul handle 1 a inv, info)
  print $ AC.run1 test mat

-- | Run the three examples with a single CUBLAS handle.  The handle is
-- released with 'Cublas.destroy' when the examples finish, and also if one
-- of them throws.
main :: IO ()
main =
  bracket Cublas.create Cublas.destroy $ \handle -> do
    mainMul handle
    mainLU handle
    mainInv handle