module Numeric.FFT.Execute
  ( execute
  ) where

import Prelude hiding (concatMap, foldr, length, map, mapM_, null, reverse, sum, zip, zipWith)
import qualified Prelude as P
import Control.Monad (when)
import qualified Control.Monad as CM
import Control.Monad.ST
import Control.Monad.Primitive (PrimMonad)
import Data.Complex
import Data.STRef
import qualified Data.Vector as V
import Data.Vector.Unboxed
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.IntMap.Strict as IM
import qualified Data.Map as M

import Numeric.FFT.Types
import Numeric.FFT.Utils
import Numeric.FFT.Special


-- | Main FFT plan execution driver.
--
-- Applies the plan's input permutation (if any), then either a single
-- base transform (when the plan has no Danielson-Lanczos steps) or the
-- base transforms followed by every Danielson-Lanczos step, and finally
-- scales by @1/n@ for an inverse transform.  Inputs of length 1 are
-- returned unchanged.
execute :: Plan -> Direction -> VCD -> VCD
execute (Plan dlinfo perm base) dir h
  | n == 1 = h
  | otherwise = runST $ do
      -- Put the input into the ordering expected by the plan.
      mhin <- case perm of
        Nothing -> thaw h
        -- 'backpermute' builds a fresh vector, so 'unsafeThaw' cannot
        -- alias the caller's input.
        Just p -> unsafeThaw $ backpermute h p
      mhtmp <- MV.replicate n 0
      vout <-
        if V.null dlinfo
          then do
            -- No Danielson-Lanczos steps: one base transform covers it all.
            applyBase base sign mhin mhtmp
            return mhtmp
          else do
            -- Base transforms at the "bottom" of the algorithm, then the
            -- Danielson-Lanczos steps, alternating between the two buffers.
            applyBases base sign mhin mhtmp
            runDLSteps sign dlinfo mhtmp mhin
      scaleIfInverse dir n vout
      unsafeFreeze vout
  where
    n = length h         -- Input vector length.
    sign = rootSign dir  -- Root of unity sign.


-- | Monadic FFT plan execution driver -- used by Rader's algorithm
-- for convolutions.  Transforms @hin@, writing the result into @hout@.
executeM :: Plan -> Direction -> MVCD s -> MVCD s -> ST s ()
executeM (Plan dlinfo perm base) dir hin hout
  | n == 1 = MV.copy hout hin
  | otherwise = do
      htmp <- MV.replicate n 0
      -- Input permutation.
      case perm of
        Nothing -> MV.copy htmp hin
        Just p -> backpermuteM n p hin htmp
      -- Base transforms write into hout; the Danielson-Lanczos steps then
      -- alternate between hout and htmp.
      applyBases base sign htmp hout
      _ <- runDLSteps sign dlinfo hout htmp
      -- After an odd number of steps the final result is in htmp, not hout.
      when (odd $ V.length dlinfo) $ MV.copy hout htmp
      -- Output scaling for inverse transform.
      scaleIfInverse dir n hout
  where
    n = MV.length hin    -- Input vector length.
    sign = rootSign dir  -- Root of unity sign.


-- | Root of unity sign passed to the base transforms and the
-- Danielson-Lanczos steps: 1 for 'Forward', -1 for 'Inverse'.
rootSign :: Direction -> Int
rootSign dir = case dir of
  Forward -> 1
  Inverse -> -1


-- | Multiple base transform application for "bottom" of algorithm: the
-- base transform is applied to each @baseSize base@-sized slice of the
-- input (as produced by 'slicemvecs'), writing into the matching slice
-- of the output.
applyBases :: BaseTransform -> Int -> MVCD s -> MVCD s -> ST s ()
applyBases base sign xmin xmout =
  V.zipWithM_ (applyBase base sign) (slicemvecs bsize xmin) (slicemvecs bsize xmout)
  where
    bsize = baseSize base  -- Size of base transform.


-- | Run every Danielson-Lanczos step in order, ping-ponging between two
-- work buffers: each step reads the buffer holding the previous result
-- and overwrites the other one.  The first buffer must hold the input to
-- the first step.  Returns the buffer holding the output of the last
-- step (the first buffer when there are no steps).
runDLSteps :: Int -> V.Vector (Int, Int, VVVCD, VVVCD)
           -> MVCD s -> MVCD s -> ST s (MVCD s)
runDLSteps sign dlinfo src dst = do
  (result, _) <- V.foldM step (src, dst) dlinfo
  return result
  where
    step (cur, other) dlstep = do
      dl sign dlstep cur other
      return (other, cur)


-- | Output scaling for an inverse transform: multiply the first @n@
-- elements of the vector by @1/n@.  Does nothing for 'Forward'.
scaleIfInverse :: Direction -> Int -> MVCD s -> ST s ()
scaleIfInverse dir n mv = when (dir == Inverse) $ do
  let s = 1.0 / fromIntegral n :+ 0
  forM_ (enumFromN 0 n) $ \i -> do
    x <- MV.unsafeRead mv i
    MV.unsafeWrite mv i $ s * x


-- | Single Danielson-Lanczos step: process all duplicates and
-- concatenate into a single vector.
--
-- The input is cut into consecutive chunks of @wfac@ elements, each
-- processed independently.  Within a chunk, the data is split (via
-- 'slicemvecs') into @split@ blocks of @ns = wfac `div` split@
-- elements, and output block @r@ is the sum over input blocks @c@ of
-- the element-wise product of block @c@ with the twiddle vector @d r c@.
dl :: Int -> (Int, Int, VVVCD, VVVCD) -> MVCD s -> MVCD s -> ST s ()
dl sign (wfac, split, dmatp, dmatm) mhin mhout =
  V.zipWithM_ doone (slicemvecs wfac mhin) (slicemvecs wfac mhout)
  where
    -- Twiddle factors for block (r, c) (both zero-indexed) of the step's
    -- split x split block matrix.  Each block is diagonal, so only its
    -- ns diagonal entries are stored; the table depends on the sign.
    dmat = if sign == 1 then dmatp else dmatm
    d r c = (dmat V.! r) V.! c
    -- Size of each diagonal sub-matrix.
    ns = wfac `div` split
    -- Index vectors.
    nsidxs = enumFromN 0 ns
    splitidxs = enumFromN 1 (split - 1)

    -- Process one duplicate by processing all rows and writing the
    -- results into a single output vector.
    doone :: MVCD s -> MVCD s -> ST s ()
    doone vin vout = do
        let vs = (slicemvecs ns vin, slicemvecs ns vout)
        mapM_ (single vs) $ enumFromN 0 split
      where
        -- Multiply input block c by the diagonal for (r, c) and either
        -- overwrite (first == True) or accumulate into output block vo.
        mult :: VMVCD s -> MVCD s -> Int -> Bool -> Int -> ST s ()
        mult vins vo r first c = do
          let vi = vins V.! c
              dvals = d r c
          forM_ nsidxs $ \i -> do
            xi <- MV.unsafeRead vi i
            -- The first block overwrites, so the output need not be
            -- zeroed beforehand.
            xo <- if first then return 0 else MV.unsafeRead vo i
            MV.unsafeWrite vo i (xo + xi * dvals ! i)

        -- Compute output block r from all input blocks in that row.
        single :: (VMVCD s, VMVCD s) -> Int -> ST s ()
        single (vis, vos) r = do
          mult vis (vos V.! r) r True 0
          mapM_ (mult vis (vos V.! r) r False) splitidxs


-- | Apply a base transform to a single vector.
applyBase :: BaseTransform -> Int -> MVCD s -> MVCD s -> ST s ()

-- Simple DFT algorithm: direct O(sz^2) evaluation, taking the factor for
-- output i and input k from the direction's table at index (i * k) mod sz.
applyBase (DFTBase sz wsfwd wsinv) sign mhin mhout = do
  h <- freeze mhin
  forM_ (enumFromN 0 sz) $ \i -> MV.unsafeWrite mhout i (doone h i)
  where
    ws = if sign == 1 then wsfwd else wsinv
    doone h i = sum $ zipWith (*) h $ generate sz (\k -> ws ! (i * k `mod` sz))

-- Special hard-coded cases.
applyBase (SpecialBase sz) sign mhin mhout =
  case IM.lookup sz specialBases of
    Just f -> f sign mhin mhout
    Nothing -> error "invalid problem size for SpecialBase"

-- Rader prime-length FFT.
applyBase (RaderBase sz outperm bfwd binv csz cplan) sign mhin mhout = do
  -- Padding size.
  let pad = csz - (sz - 1)
  -- Permuted input vector zero-padded to the convolution size csz for
  -- fast convolution: element 1 goes to position 0, elements 2..sz-1 go
  -- to the tail, and the pad positions in between stay zero.
  apad <- MV.replicate csz 0
  forM_ (enumFromN 0 csz) $ \i -> do
    val <-
      if i == 0
        then MV.unsafeRead mhin 1
        else if i > pad
          then MV.unsafeRead mhin $ i - pad + 1
          else return 0
    MV.unsafeWrite apad i val
  -- FFT-based convolution calculation.
  convtmp <- MV.replicate csz 0
  executeM cplan Forward apad convtmp
  let bmult = if sign == 1 then bfwd else binv
  forM_ (enumFromN 0 csz) $ \i -> do
    x <- MV.unsafeRead convtmp i
    MV.unsafeWrite convtmp i $ x * (bmult ! i)
  executeM cplan Inverse convtmp apad
  conv <- unsafeFreeze apad
  -- Input vector sum.  The strict modify avoids building a chain of sz
  -- unevaluated additions inside the reference.
  sumhref <- newSTRef 0
  forM_ (enumFromN 0 sz) $ \i -> do
    val <- MV.unsafeRead mhin i
    modifySTRef' sumhref (+ val)
  sumh <- readSTRef sumhref
  -- Write output based on output generator index ordering.
  h0 <- MV.unsafeRead mhin 0
  forM_ (enumFromN 0 sz) $ \i -> do
    let (idx, val) = case i of
          0 -> (0, sumh)
          _ -> (outperm ! (i - 1), h0 + conv ! (i - 1))
    MV.unsafeWrite mhout idx val