Ticket #1887: mat_mult_ndp.hs

File mat_mult_ndp.hs, 1.9 KB (added by mrd, 5 years ago)

source code

Line 
1import Control.Monad
2import Control.Parallel
3import Control.Parallel.Strategies
4import Data.Array.Unboxed
5import Data.Char
6import qualified Data.ByteString.Char8 as SS
7import Data.List
8import System.Environment
9import System.IO
10import Data.Array.Parallel.Unlifted
11import Data.Array.Parallel.Unlifted.Distributed
12
13main = do
14  [_, pstr, infile, outfile] <- getArgs
15  let p = read pstr :: Int
16  setGang p
17  iH <- openFile infile ReadMode
18  cs <- SS.hGetContents iH
19  let Just (n, cs') = readInt cs
20      (vals1,vals2) = splitAt (n * n) (unfoldr readDouble cs')
21      inmat1        = map toU . fst $ segm n vals1
22      inmat2        = map toU . snd $ segm n vals2
23      outmat        = inmat1 `par` inmat2 `par` 
24                        ndp_matmult n inmat1 inmat2
25  hClose iH 
26  oH <- openFile outfile WriteMode
27  hPrint oH n
28  forM_ [0 .. n-1] $ \ i -> do
29    forM_ [0 .. n-1] $ \ j -> do
30      hPutStr oH $ show (outmat!(i,j))
31      hPutStr oH " "
32    hPutStr oH "\n"
33  hClose oH   
34  return ()
35
36ndp_matmult :: Int -> [UArr Double] -> [UArr Double] -> UArray (Int, Int) Double
37ndp_matmult n inmat1 inmat2 = listArray ((0, 0), (n - 1, n - 1))
38--                              . map dotP $ sequence [inmat1, inmat2]
39                              $  parFlatMap (parList rwhnf)
40                                   (\ x -> map (dotP x) inmat2) inmat1
41
42
43-- dotP [a, b] = sumU (zipWithU (*) a b) :: Double
44dotP a b = sumU (zipWithU (*) a b) :: Double
45
46
47readInt cs = SS.readInt cs >>= (\ (i, cs') -> return (i, snd (SS.span isSpace cs')))
48readDouble cs = do
49  let (s, cs') = SS.break isSpace cs
50  case reads (SS.unpack s) of
51    []       -> Nothing
52    [(d, _)] -> Just (d :: Double, snd $ SS.span isSpace cs')
53
54segm n [] = ([], replicate n [])
55segm n ds = (row:rows, cols')
56  where
57    (row, ds') = splitAt n ds
58    (rows, cols) = segm n ds'
59    cols' = zipWith (:) row cols