{-# LANGUAGE CPP #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-|
    Inverts the Hilbert matrix (optionally with 1 added to its diagonal)
    by Gauss-Jordan elimination over AERN interval arithmetic.

    The elimination is first run at the granularity @initialGran@.  If no
    pivot can be separated from zero, or the resulting inverse is less
    precise than @targetPrec@, the whole computation is restarted at a
    larger granularity (see @incrementGran@).
-}
module Main where

import qualified Data.Number.ER.Real as AERN
import Data.Number.ER.BasicTypes
import Data.Number.ER.Misc

import Data.Maybe
import qualified Data.List as List
import qualified Data.Map as Map
import qualified Data.Array.IArray as IAr
import qualified Data.Array.MArray as MAr
import qualified Data.Array.ST as STAr
import qualified Data.Ix as Ix
import qualified Data.Array.Base as BAr
import Control.Monad.ST
import GHC.Arr

#ifdef USE_MPFR
-- NOTE(review): this branch selects the same base type as the other branch,
-- so defining USE_MPFR currently has no effect.  To compute with MPFR
-- floats, comment out the first line below and uncomment the second.
type B = AERN.BAP -- use pure Haskell floats
--type B = AERN.BMPFR -- use MPFR floats
#else
type B = AERN.BAP -- use pure Haskell floats
#endif

-- | Approximation types built over the selected base type B.
type RA = AERN.RA B
type IRA = AERN.IRA B

-- | Size of the (square) test matrix.
testMatrixN :: Int
testMatrixN = 100

-- | How the granularity grows from one elimination attempt to the next.
incrementGran :: Granularity -> Granularity
incrementGran = (+) 50

-- Problem configuration: keep exactly one group of settings active.
-- The starting granularities were chosen empirically; when one is too small
-- the elimination simply retries with a larger granularity.

-- Hilbert 100x100 matrix:
addOneDiag :: Bool
addOneDiag = False
targetPrec = 167 -- approx 50 decimal digits after the point
initialGran :: Granularity
initialGran = 2050 -- 100x100
--initialGran = 2388 -- 100x100 Norbert's
--initialGran = 750 -- 50x50
--initialGran = 300 -- 10x10
--targetPrec = 34 -- approx 10 decimal digits after the point
--initialGran = 1350
--initialGran = 50 -- 50x50

-- Hilbert matrix + 1:
--addOneDiag = True
--targetPrec = 167 -- approx 50 decimal digits after the point
--initialGran = 200
--targetPrec = 34 -- approx 10 decimal digits after the point
--initialGran = 50

-- | Invert the test matrix and print the sum of all entries of the inverse.
--   For the exact n x n Hilbert matrix that sum equals n^2, which makes a
--   convenient sanity check of the result.
main :: IO ()
main = do
    AERN.initialiseBaseArithmetic (0 :: RA)
    putStrLn $
        "Inverting the " ++ show n ++ "x" ++ show n ++ " Hilbert matrix "
        ++ "with target binary precision " ++ show targetPrec ++ "..."
--    putStrLn $
--        "sorted matrix elements = \n" ++ (unlines $ map show elemsSortedByPrec)
    putStrLn $ "sum of all elements in inverted matrix = " ++ show (sum invElems)
--    putStrLn $ show (Matrix n n rarr)
  where
    n = testMatrixN
    invElems = IAr.elems rarr
    -- entries ordered by the lower bound of their interval width
    elemsSortedByPrec = List.sortBy comparePrec invElems
      where
        comparePrec a b = compare (widthLO a) (widthLO b)
        widthLO x = fst $ AERN.bounds $ hi - lo
          where
            (lo, hi) = AERN.bounds x
    rarr = STAr.runSTArray $ do
        mInv@(Matrix _ _ rowsInv) <- invert testMatrix
--        m <- testMatrix initialGran
--        mUnit@(Matrix _ _ rowsUnit) <- multM m mInv
        return rowsInv

-- | Build the @testMatrixN@ x @testMatrixN@ test matrix with every entry
--   computed at (at least) the given granularity.
testMatrix :: Granularity -> ST s (STMatrix s IRA)
testMatrix gran = do
    marr <- MAr.newArray ((1, 1), (n, n)) 0
    mapM_ (updateCell marr) assocsGran
    return $ Matrix n n marr
  where
    n = testMatrixN
    assocsGran = map (mapSnd $ AERN.setMinGranularityOuter gran) assocs
    assocs = assocsHilbert gran n
    -- assocs = assocsMini  -- small 2x2 example, handy for debugging
    assocsMini =
        [ ((1, 1), 1), ((1, 2), 3)
        , ((2, 1), 2), ((2, 2), 0)
        ]
    updateCell marr (ix, el) = unsafeMatrixWrite marr n ix el

-- | Entries of the n x n Hilbert matrix, H(i,j) = 1/(i+j-1) for 1-based
--   indices i, j, each computed at granularity @gran@.  When @addOneDiag@
--   is set, 1 is added to every diagonal entry.
assocsHilbert gran n =
    [((i, j), coeff i j) | i <- [1 .. n], j <- [1 .. n]]
  where
    coeff i j
        | addOneDiag && i == j = 1 + hilbertEntry
        | otherwise = hilbertEntry
      where
        -- 1-based indices, hence "- 1" (the diagonal starts at H(1,1) = 1)
        hilbertEntry =
            recip $ AERN.setMinGranularityOuter gran $ iRA + jRA - 1
        iRA = fromInteger $ toInteger i
        jRA = fromInteger $ toInteger j

-- | Invert the matrix produced by @getMatrix@ (a function from granularity
--   to a square matrix) by augmenting it with the identity and running
--   Gauss-Jordan elimination.
invert getMatrix = gaussElim getMatrixI
  where
    getMatrixI gran = getMatrix gran >>= addIdentity

-- | Gauss-Jordan elimination with row pivoting on an augmented matrix
--   @[A | I]@ produced by @getMatrix@.  Rows are never moved physically; a
--   permutation array maps logical rows to physical rows.  On success the
--   right-hand block (the inverse, rows un-permuted) is returned.  Whenever
--   a pivot cannot be separated from zero, or the result is less precise
--   than @targetPrec@, the computation restarts at a larger granularity.
gaussElim getMatrix = elimWithMinGran initialGran
  where
    -- one complete elimination attempt at the given granularity
    elimWithMinGran workingGran = do
        mI@(Matrix _ rowN _) <- getMatrix workingGran
        idPerm <- MAr.newListArray (1, rowN) [1 .. rowN]
        elimAtRow mI 1 idPerm
      where
        -- report why this attempt is abandoned, then start again
        retry reason =
            unsafePrint (reason ++ " at granularity " ++ show workingGran) $
                elimWithMinGran (incrementGran workingGran)

        -- process logical pivot row i, then continue with row i+1
        elimAtRow mI@(Matrix colN rowN mIarr) i perm = do
            success <- ensureNonZeroDiag
            if success
                then do
                    normaliseRow
                    eliminateColumn
                    if i == rowN
                        then finish
                        else elimAtRow mI (i + 1) perm
                -- every candidate pivot contains 0: need a larger granularity
                else retry "failed to divide"
          where
            -- all pivots done: extract the inverse and check its precision
            finish = do
                mInv <- permuteRowsDropCols perm (colN - rowN) mI
                mPrec <- getMatrixPrecision mInv
                if mPrec >= targetPrec
                    then
                        unsafePrint
                            ("precision " ++ show mPrec
                             ++ " succeeded at granularity " ++ show workingGran) $
                            return mInv
                    else retry ("insufficient precision " ++ show mPrec)

            -- Make the pivot (i,i) exclude zero by swapping in the first row
            -- at or below i whose column-i entry excludes zero; False if none.
            ensureNonZeroDiag = do
                maybeOffset <- findNonZeroRow
                case maybeOffset of
                    Nothing -> return False
                    Just 0 -> return True
                    Just offset -> do
                        swap i (i + offset) perm
                        return True

            -- offset from row i of the first row whose column-i entry
            -- does not contain zero
            findNonZeroRow = do
                column <- mapM getElemPerm [(i, rowIx) | rowIx <- [i .. rowN]]
                return $ List.findIndex (\e -> not $ 0 `AERN.refines` e) column

            -- read the entry at (column, logical row)
            getElemPerm (colIx, rowIx) = do
                rowIxPerm <- unsafePermRead perm rowIx
                unsafeMatrixRead mIarr rowN (colIx, rowIxPerm)

            -- divide pivot row i by its pivot so that entry (i,i) becomes 1
            normaliseRow = do
                rowIxPerm <- unsafePermRead perm i
                e <- unsafeMatrixRead mIarr rowN (i, rowIxPerm)
                unsafeMatrixWrite mIarr rowN (i, rowIxPerm) 1
                mapM_ (divideCellBy e rowIxPerm) [(i + 1) .. colN]

            divideCellBy e rowIxPerm colIx = do
                e2 <- unsafeMatrixRead mIarr rowN (colIx, rowIxPerm)
                unsafeMatrixWrite mIarr rowN (colIx, rowIxPerm) (e2 / e)

            -- subtract multiples of row i from every other row to clear column i
            eliminateColumn = do
                iRowPerm <- unsafePermRead perm i
                mapM_ (eliminateColumnRow iRowPerm) $
                    [1 .. (i - 1)] ++ [(i + 1) .. rowN]

            eliminateColumnRow iRowPerm rowIx = do
                rowIxPerm <- unsafePermRead perm rowIx
                -- the old column-i entry is the multiplier for row i
                c <- unsafeMatrixRead mIarr rowN (i, rowIxPerm)
                unsafeMatrixWrite mIarr rowN (i, rowIxPerm) 0
                mapM_ (eliminateCell iRowPerm rowIxPerm c) [(i + 1) .. colN]

            eliminateCell iRowPerm rowIxPerm c colIx = do
                ei <- unsafeMatrixRead mIarr rowN (colIx, iRowPerm) -- in row i
                er <- unsafeMatrixRead mIarr rowN (colIx, rowIxPerm) -- in current row
                unsafeMatrixWrite mIarr rowN (colIx, rowIxPerm) (er - c * ei)

-- | Exchange entries i1 and i2 of a permutation array.
swap :: Int -> Int -> STAr.STUArray s Int Int -> ST s ()
swap i1 i2 perm = do
    a1 <- unsafePermRead perm i1
    a2 <- unsafePermRead perm i2
    unsafePermWrite perm i1 a2
    unsafePermWrite perm i2 a1

-- | Write entry i of a permutation array whose bounds start at 1
--   (unchecked access).
unsafePermWrite :: STAr.STUArray s Int Int -> Int -> Int -> ST s ()
unsafePermWrite permArr i e = BAr.unsafeWrite permArr (i - 1) e

-- | Read entry i of a permutation array whose bounds start at 1
--   (unchecked access).
unsafePermRead :: STAr.STUArray s Int Int -> Int -> ST s Int
unsafePermRead permArr i = BAr.unsafeRead permArr (i - 1)

-- | Append an identity block of @rowN@ columns to the right of the matrix,
--   giving the augmented matrix @[A | I]@.
addIdentity :: (STMatrix s IRA) -> ST s (STMatrix s IRA)
addIdentity (Matrix colN rowN marr) = do
    mElems <- MAr.getElems marr
    mIarr <- MAr.newListArray ((1, 1), (colN + rowN, rowN)) $ mElems ++ idElems
    return $ Matrix (colN + rowN) rowN mIarr
  where
    -- identity entries in storage order: column by column, 1 where row == column
    idElems = [if c == r then 1 else 0 | c <- [1 .. rowN], r <- [1 .. rowN]]

-- | A dense matrix whose entries are stored column after column in an array
--   indexed by @(column, row)@.  The first field is the number of columns and
--   the second the number of rows, matching how every function in this
--   module constructs and pattern-matches it.
data Matrix marr el = Matrix
    { mxColN :: Int                      -- ^ number of columns
    , mxRowN :: Int                      -- ^ number of rows
    , mxRows :: marr (ColIx, RowIx) el   -- ^ entries, indexed by (column, row)
    }

type ColIx = Int
type RowIx = Int

type IMatrix el = Matrix Array el
type STMatrix s el = Matrix (STArray s) el

instance
    (IAr.IArray marr el,
--     IAr.IArray marr (marr Int el),
     Show el) =>
    Show (Matrix marr el)
  where
    show (Matrix colN rowN rows) =
        "\nMatrix:\n" ++ concatMap showCol [1 .. colN]
      where
        -- one line per entry of the column
        showCol colIx =
            unlines $ map showCell [(colIx, rowIx) | rowIx <- [1 .. rowN]]
        -- the index, then colIx dots, then the entry
        showCell ix@(colIx, _) =
            show ix ++ replicate colIx '.' ++ show (rows IAr.! ix)

-- | Lowest precision (as reported by AERN.getPrecision) among all entries;
--   assumes the matrix is non-empty.
getMatrixPrecision (Matrix _ _ marr) = do
    entries <- MAr.getElems marr
    return $ foldl1 min $ map AERN.getPrecision entries

-- | Write entry (column i, row j) of a matrix array that has @rowN@ rows and
--   bounds starting at (1,1), by computing the column-major offset directly
--   (no bounds check).
unsafeMatrixWrite marr rowN (i, j) e =
    BAr.unsafeWrite marr (rowN * (i - 1) + j - 1) e
    -- checked alternative: MAr.writeArray marr (i,j) e

-- | Read entry (column i, row j); same layout assumptions as unsafeMatrixWrite.
unsafeMatrixRead marr rowN (i, j) =
    BAr.unsafeRead marr (rowN * (i - 1) + j - 1)
    -- checked alternative: MAr.readArray marr (i,j)

-- | Copy the matrix into a fresh one, dropping the first @dropN@ columns and
--   placing physical row @perm[r]@ at logical row @r@.
permuteRowsDropCols ::
    (STAr.STUArray s Int Int) ->
    Int {-^ drop this many first columns -} ->
    (STMatrix s IRA) ->
    ST s (STMatrix s IRA)
permuteRowsDropCols perm dropN (Matrix colN rowN marr) = do
    (_, permN) <- MAr.getBounds perm
    rarr <- MAr.newArray ((1, 1), (newColN, permN)) 0
    mapM_ (copyElem rarr permN)
        [(colIx, rowIx) | colIx <- [1 .. newColN], rowIx <- [1 .. permN]]
    return (Matrix newColN permN rarr)
  where
    newColN = colN - dropN
    -- the source is addressed with its own row count, the result with permN
    copyElem rarr permN (colIx, rowIx) = do
        permRowIx <- unsafePermRead perm rowIx
        e <- unsafeMatrixRead marr rowN (colIx + dropN, permRowIx)
        unsafeMatrixWrite rarr permN (colIx, rowIx) e

-- | Entry-wise sum of two matrices of equal dimensions.
addM m1 m2
    | mxColN m1 == mxColN m2 && mxRowN m1 == mxRowN m2 = do
        marr <- MAr.newArray ((1, 1), (colN, rowN)) 0
        mapM_ (addCell marr) [(c, r) | c <- [1 .. colN], r <- [1 .. rowN]]
        return (Matrix colN rowN marr)
    | otherwise = error "Matrix: addM mismatch"
  where
    colN = mxColN m1
    rowN = mxRowN m1
    marr1 = mxRows m1
    marr2 = mxRows m2
    addCell marr ix = do
        elem1 <- unsafeMatrixRead marr1 rowN ix
        elem2 <- unsafeMatrixRead marr2 rowN ix
        unsafeMatrixWrite marr rowN ix (elem1 + elem2)

-- | Matrix product m1 * m2; requires columns of m1 == rows of m2.
multM m1 m2
    | colN1 == rowN2 = do
        marr <- MAr.newArray ((1, 1), (colN, rowN)) 0
        mapM_ (multCell marr) [(c, r) | c <- [1 .. colN], r <- [1 .. rowN]]
        return (Matrix colN rowN marr)
    | otherwise = error "Matrix: multM mismatch"
  where
    colN1 = mxColN m1
    rowN1 = mxRowN m1
    colN2 = mxColN m2
    rowN2 = mxRowN m2
    colN = colN2
    rowN = rowN1
    marr1 = mxRows m1
    marr2 = mxRows m2
    -- result(row r, column c) = sum over k of m1(row r, col k) * m2(row k, col c)
    multCell marr (colIx, rowIx) = do
        rowOf1 <- mapM (\k -> unsafeMatrixRead marr1 rowN1 (k, rowIx)) [1 .. colN1]
        colOf2 <- mapM (\k -> unsafeMatrixRead marr2 rowN2 (colIx, k)) [1 .. rowN2]
        unsafeMatrixWrite marr rowN (colIx, rowIx) (sum $ zipWith (*) rowOf1 colOf2)