import qualified Prelude
import Feldspar
import Feldspar.Vector
import Feldspar.Compiler
import Feldspar.Matrix

-- This code accompanies our lecture notes
-- "Feldspar: Application and Implementation", in CEFP,
-- Springer LNCS 7241, 2012.
-- The code is documented in the notes.

-- | Square a word.
square :: Data WordN -> Data WordN
square x = x * x

-- | Double the argument when it is odd (bit 0 set); otherwise return it
-- unchanged.
f :: Data Int32 -> Data Int32
f i = testBit i 0 ? (2 * i, i)

-- | The array @[1, 2, .., n]@, built element-wise with 'parallel'.
arr1n :: Data WordN -> Data [WordN]
arr1n n = parallel n (+ 1)

-- | Square every element of an array.
squareEach :: Data [WordN] -> Data [WordN]
squareEach as = parallel (getLength as) (square . getIx as)

-- | The running products @[1!, 2!, .., n!]@: the state starts at 1, step
-- @ix@ multiplies it by @ix + 1@ and emits the new state as the element.
sfac :: Data WordN -> Data [WordN]
sfac n = sequential n 1 step
  where
    step ix st = (j, j)
      where
        j = (ix + 1) * st

-- | Fibonacci numbers with @fib 0 = fib 1 = 1@, computed by iterating the
-- pair step @(a, b) -> (b, a + b)@ @n@ times.
fib :: Data Index -> Data Index
fib n = fst $ forLoop n (1, 1) $ \_ (a, b) -> (b, a + b)

-- | How many times @n@ can be halved while it stays above 1: this is
-- @floor (logBase 2 n)@ for @n >= 1@, and 0 for @n == 0@.
intLog :: Data WordN -> Data WordN
intLog n =
  fst $ whileLoop (0, n) (\(_, b) -> b > 1) (\(a, b) -> (a + 1, b `div` 2))

-- | Twiddle factor @exp (-2 * pi * i * k / n)@.
tw :: Data WordN -> Data WordN -> Data (Complex Float)
tw n k = exp (-2 * pi * iunit * i2n k / i2n n)

-- | The @n@ twiddle factors @tw n k@ for @k = 0 .. n-1@.
tws :: Data WordN -> Vector1 (Complex Float)
tws n = indexed n (tw n)

-- | Vector of the squares of @1 .. n@.
squares :: Data WordN -> Vector1 WordN
squares n = map square (1 ... n)

-- | Toggle bit @k@ of @i@.
flipBit :: Data Index -> Data Index -> Data Index
flipBit i k = i `xor` bit k

-- | Toggle bit @k@ of every element.
flips :: Data WordN -> Vector1 WordN -> Vector1 WordN
flips k = map (`flipBit` k)

-- | Sum of the squares of @1 .. n@.
sumSqVn :: Data WordN -> Data WordN
sumSqVn n = fold (+) 0 (squares n)

-- Transforms

-- DFT

-- | Direct O(n^2) discrete Fourier transform.
--
-- The twiddle exponent @j * k@ is reduced modulo @n@ first. This does not
-- change the mathematical result, because @tw n@ has period @n@. It keeps
-- the value converted by 'i2n' below @n@, so the Float phase stays accurate
-- for large vectors.
dft :: Vector (Data (Complex Float)) -> Vector (Data (Complex Float))
dft xs = indexed n (\k -> sum (indexed n (\j -> xs ! j * tw n ((j * k) `mod` n))))
  where
    n = length xs

-- FFT

-- | Permute a vector by pre-composing its index function with @perm@:
-- element @i@ of the result is element @perm i@ of the input.
--
-- Only a single-segment vector (@Indexed l ixf Empty@) is supported. Any
-- other shape is reported with an explicit error rather than an anonymous
-- pattern-match failure.
premap :: (Data Index -> Data Index) -> Vector a -> Vector a
premap perm (Indexed l ixf Empty) = indexed l (ixf . perm)
premap _ _ =
  Prelude.error "premap: only single-segment vectors (Indexed l ixf Empty) are supported"

-- | Reverse a vector of length @2^k@. Index @i@ is mapped to
-- @i `xor` (2^k - 1)@, which complements its low @k@ bits.
-- The permutation acts on indices only, so no constraint on the element
-- type is needed.
revp :: Data Index -> Vector1 a -> Vector1 a
revp k = premap (`xor` (2 ^ k - 1))

-- | Radix-2 butterfly on bit @k@. Element @i@ is paired with element
-- @i `xor` 2^k@.
--
-- * Where bit @k@ of @i@ is clear, the output is the sum of the pair.
-- * Where it is set, the output is (partner - self).
bfly :: Data Index -> Vector (Data (Complex Float)) -> Vector (Data (Complex Float))
bfly k as = indexed l ixf
  where
    l = length as
    ixf i = testBit i k ? (b - a, a + b)
      where
        a = as ! i
        b = as ! flipBit i k

-- Recursive
-- works on sub-arrays of length 2^n

-- | Recursive FFT stages on a vector of length @2^n@.
--
-- Each stage applies 'bfly' and then 'twids0' on bit @n-1@, then recurses
-- on @n-1@. The depth @n@ is a host-level 'Index', so the recursion is
-- unfolded in Haskell when the program is built.
-- The output is in bit-reversed order (see 'fft0').
fftr0 :: Index -> Vector (Data (Complex Float)) -> Vector (Data (Complex Float))
fftr0 0 = id
fftr0 n = fftr0 n' . twids0 vn' . bfly vn'
  where
    n' = n - 1
    vn' = value n'

-- Needs bit reversal on the output

-- | FFT of a vector of length @2^n@: the recursive stages of 'fftr0',
-- followed by a bit-reversal permutation of the output.
fft0 :: Index -> Vector (Data (Complex Float)) -> Vector (Data (Complex Float))
fft0 n = bitRev (value n) . fftr0 n

-- | A mask with the low @k@ bits set. It is the complement of an
-- all-ones word shifted left by @k@.
oneBitsN :: Data Index -> Data Index
oneBitsN k = complement (shiftLU (complement 0) k)

-- bit reversal

-- | Reverse the low @n@ bits of @a@ and keep the bits above them unchanged.
--
-- 'reverseBits' reverses the whole word, which moves the masked low bits to
-- the top. Rotating left by @n@ brings them back to the bottom.
bitr :: Data Index -> Data Index -> Data Index
bitr n a =
  let mask = oneBitsN n
  in (complement mask .&. a) .|. rotateLU (reverseBits (mask .&. a)) n

-- | Bit-reversal permutation on the low @n@ bits of the index.
bitRev :: Data Index -> Vector a -> Vector a
bitRev n = premap (bitr n)

-- | A small 4-element complex test vector.
dt4 :: Vector1 (Complex Float)
dt4 = zipWith (+.) (value [1, 2, 3, 1 :: Float]) (value [4, -2, 2, 2])

-- Iterative

-- | Iterative FFT of a vector of length @2^n@.
--
-- Loop step @k@ applies 'bfly' and then 'twids0' on bit @n-1-k@. The result
-- is then bit-reversed.
fft1 :: Data Index -> Vector (Data (Complex Float)) -> Vector (Data (Complex Float))
fft1 n as = bitRev n $ forLoop n as (\k -> twids0 (n - 1 - k) . bfly (n - 1 - k))

-- | Like 'fft1', but the twiddle multiplication is fused into the butterfly.
fft2 :: Data Index -> Vector (Data (Complex Float)) -> Vector (Data (Complex Float))
fft2 n as = bitRev n $ forLoop n as (\k -> bfly2 (n - 1 - k))
  where
    -- Butterfly on bit k that also multiplies the difference by the twiddle.
    -- The parameter is named vs to avoid shadowing the outer as.
    bfly2 k vs = indexed l ixf
      where
        l = length vs
        ixf i = testBit i k ? (t * (b - a), a + b)
          where
            a = vs ! i
            b = vs ! flipBit i k
            t = tw (2 ^ (k + 1)) (i `mod` (2 ^ k))

-- | Multiply the elements whose bit @k@ is set by the twiddle
-- @tw (2^(k+1)) (i `mod` 2^k)@. All other elements pass through unchanged.
twids0 :: Data Index -> Vector1 (Complex Float) -> Vector1 (Complex Float)
twids0 k as = indexed l ixf
  where
    l = length as
    ixf i = testBit i k ? (t * (as ! i), as ! i)
      where
        t = tw (2 ^ (k + 1)) (i `mod` (2 ^ k))

-- | Same twiddles as 'twids0', expressed as powers of the @2^n@-th root.
-- The index is @(i `mod` 2^k) `shiftL` (n-1-k)@.
twids1 :: Data Index -> Data Index -> Vector1 (Complex Float) -> Vector1 (Complex Float)
twids1 n k as = indexed (length as) ixf
  where
    ixf i = testBit i k ? (t * (as ! i), as ! i)
      where
        t = tw (2 ^ n) ((i `mod` (2 ^ k)) .<<. (n - 1 - k))

-- | Like 'twids1', but the twiddles are looked up in the table
-- @ts = [tw (2^n) j | j <- [0 .. 2^(n-1) - 1]]@. The table is wrapped in
-- 'force', presumably so it is materialised once rather than recomputed per
-- element (TODO confirm against the Feldspar 'force' docs).
twids2 :: Data Index -> Data Index -> Vector1 (Complex Float) -> Vector1 (Complex Float)
twids2 n k as = indexed (length as) ixf
  where
    ts = force $ indexed (2 ^ (n - 1)) (tw (2 ^ n))
    ixf i = testBit i k ? (t * (as ! i), as ! i)
      where
        t = ts ! ((i `mod` (2 ^ k)) .<<. (n - 1 - k))

-- | Like 'fft1', but using the table-driven twiddles of 'twids2'.
fft3 :: Data Index -> Vector1 (Complex Float) -> Vector1 (Complex Float)
fft3 n as = bitRev n $ forLoop n as (\k -> twids2 n (n - 1 - k) . bfly (n - 1 - k))

-- Decimation in Time (Earlier versions are Decimation in Frequency)

-- | Decimation-in-time FFT.
--
-- The input is bit-reversed first. Loop step @k@ then applies 'twids2' and
-- then 'bfly' on bit @k@, for @k = 0 .. n-1@.
fft4 :: Data Index -> Vector1 (Complex Float) -> Vector1 (Complex Float)
fft4 n as = forLoop n (bitRev n as) (\k -> bfly k . twids2 n k)