import Control.Monad
import Control.Parallel
import Control.Parallel.Strategies
import Data.Array.Unboxed
import Data.Char
import qualified Data.ByteString.Char8 as SS
import Data.List
import System.Environment
import System.IO
import Data.Array.Parallel.Unlifted
import Data.Array.Parallel.Unlifted.Distributed

-- | Read two @n x n@ matrices from a text file, multiply them with
-- 'ndp_matmult', and write the product to another file.
--
-- Command line: exactly four arguments. The first is ignored; the second is an
-- 'Int' handed to 'setGang' (presumably the DPH worker-thread count -- confirm
-- against the DPH docs); the third and fourth are the input and output paths.
--
-- Input: an integer @n@, then whitespace-separated 'Double's -- the first
-- matrix in row-major order followed by the second matrix in row-major order.
--
-- Output: @n@ on its own line, then @n@ lines of @n@ values, each value
-- followed by a single space.
--
-- Bad arguments or an unreadable size header raise an 'IOError' with a
-- descriptive message, before the output file is touched.
main :: IO ()
main = do
  args <- getArgs
  (pstr, infile, outfile) <- case args of
    [_, threads, i, o] -> return (threads, i, o)
    _ -> ioError (userError "usage: <ignored> <threads> <infile> <outfile>")
  -- Accept surrounding whitespace like 'read' did, but fail with a message.
  p <- case reads pstr of
    [(k, rest)] | all isSpace rest -> return (k :: Int)
    _ -> ioError (userError ("not an integer thread count: " ++ pstr))
  setGang p
  -- Strict ByteString read: the whole file is loaded and no handle stays open.
  cs <- SS.readFile infile
  (n, cs') <- case readInt cs of
    Just header -> return header
    Nothing -> ioError (userError ("cannot read matrix size from " ++ infile))
  let (vals1, vals2) = splitAt (n * n) (unfoldr readDouble cs')
      inmat1 = map toU . fst $ segm n vals1 -- rows of the first matrix
      inmat2 = map toU . snd $ segm n vals2 -- columns of the second matrix
      -- NOTE(review): 'par' here only sparks each list to WHNF (its first
      -- cons cell); the real parallel work happens inside 'ndp_matmult'.
      outmat = inmat1 `par` inmat2 `par` ndp_matmult n inmat1 inmat2
  -- 'withFile' closes the output handle even if evaluation throws.
  withFile outfile WriteMode $ \oH -> do
    hPrint oH n
    forM_ [0 .. n - 1] $ \i ->
      hPutStrLn oH $ concatMap (\j -> show (outmat ! (i, j)) ++ " ") [0 .. n - 1]

-- | Build the @n x n@ product matrix from two lists of vectors.
--
-- Entry @(i, j)@ is @dotP (rowsA !! i) (colsB !! j)@, laid out row-major by
-- 'listArray'. 'main' passes the rows of the first matrix and the columns of
-- the second, so this is the ordinary matrix product.
--
-- Each output row is produced through 'parFlatMap'; within a row, the
-- dot products are sparked to WHNF by @parList rwhnf@.
ndp_matmult :: Int -> [UArr Double] -> [UArr Double] -> UArray (Int, Int) Double
ndp_matmult n rowsA colsB = listArray extent cells
  where
    extent = ((0, 0), (n - 1, n - 1))
    -- one full output row: the given row dotted with every column
    outputRow r = map (dotP r) colsB
    cells = parFlatMap (parList rwhnf) outputRow rowsA


-- | Dot product of two unboxed 'Double' vectors: the sum of their
-- element-wise products.
dotP :: UArr Double -> UArr Double -> Double
dotP xs ys = sumU products
  where
    products = zipWithU (*) xs ys


-- | Read a decimal 'Int' from the front of the input, skipping whitespace
-- both before and after it. Returns the number and the remaining input, or
-- 'Nothing' if the input (after leading whitespace) does not start with an
-- integer.
--
-- The leading skip matters because 'SS.readInt' does not skip whitespace
-- itself, so a file starting with a blank line or space would not parse.
readInt :: SS.ByteString -> Maybe (Int, SS.ByteString)
readInt cs = do
  (i, rest) <- SS.readInt (SS.dropWhile isSpace cs)
  return (i, SS.dropWhile isSpace rest)
-- | Parse one whitespace-delimited 'Double' token from the front of the input,
-- then drop the whitespace after it. The @Maybe (value, rest)@ shape lets it
-- drive 'unfoldr'.
--
-- Returns 'Nothing' if the input is empty, starts with whitespace, or its first
-- token is not entirely a valid 'Double' literal.
readDouble :: SS.ByteString -> Maybe (Double, SS.ByteString)
readDouble cs =
  -- Require the whole token to parse, so junk such as "1.5x" is rejected
  -- rather than truncated to 1.5; the list form also keeps the match total.
  case [d | (d, "") <- reads (SS.unpack token)] of
    (d : _) -> Just (d, SS.dropWhile isSpace rest)
    [] -> Nothing
  where
    (token, rest) = SS.break isSpace cs

-- | Cut a flat row-major list into rows of length @n@ and, in the same pass,
-- build the columns of that matrix. Returns @(rows, columns)@.
--
-- An empty input gives no rows and @n@ empty columns. The number of columns
-- is capped by the shortest row, because each row is 'zipWith'-ed onto the
-- columns below it. With @n <= 0@ and non-empty input the row list is
-- infinite.
segm :: Int -> [a] -> ([[a]], [[a]])
segm n = go
  where
    go [] = ([], replicate n [])
    go cells =
      let (row, later) = splitAt n cells
          (rowsBelow, colsBelow) = go later
       in (row : rowsBelow, zipWith (:) row colsBelow)

