import qualified Prelude
import Feldspar hiding (cos, cycle)
import Feldspar.Stream (Stream, recurrenceI, cycle)
import qualified Feldspar.Stream as S
import Feldspar.Compiler

--------------------------------------------------
-- Missing functions
--
-- Ad hoc helpers that Feldspar does not provide directly (for various
-- reasons), so they are defined here.
--------------------------------------------------

-- | Int-to-Float conversion. Appears as a call to @intToFloat@ in the
-- generated code.
intToFloat :: Data Int -> Data Float
intToFloat = function "intToFloat" (\_ -> universal) (fromInteger . toInteger)

-- | Cosine. Appears as a call to @cos@ in the generated code.
cos :: Data Float -> Data Float
cos = function "cos" (\_ -> universal) Prelude.cos

-- | @toList n v@ returns the element expressions at indices @0 .. n-1@ of
-- @v@ as a Haskell list. The vector's own length is not consulted, so the
-- caller must make sure @n@ does not exceed it.
toList :: Int -> Vector (Data a) -> [Data a]
toList n v@Indexed{} = Prelude.map ((v !) . value) [0 .. n - 1]

-- | Builds a vector from a non-empty list of element expressions.
--
-- Every slot of the array is first filled with the head of the list, and
-- then indices @1 .. len-1@ are overwritten one 'setIx' at a time. The
-- generated code is therefore very inefficient (one update per element).
-- The Haskell side walks the list once, instead of indexing it with @!!@ on
-- every step.
--
-- An empty list fails with a descriptive error once the missing seed
-- element is demanded.
fromList :: Storable a => [Data a] -> Vector (Data a)
fromList ls =
    unfreezeVector (value len) (writeFrom 1 rest (parallel (value len) (const seed)))
  where
    len = Prelude.length ls
    -- The first element seeds every slot; the remaining ones are written
    -- by index afterwards.
    (seed, rest) = case ls of
      x : xs -> (x, xs)
      []     -> (Prelude.error "fromList: empty list", [])
    -- Write the given elements at consecutive indices starting at @i@.
    writeFrom _ []       arr = arr
    writeFrom i (y : ys) arr = writeFrom (i + 1) ys (setIx arr (value i) y)

--------------------------------------------------
-- Examples
--------------------------------------------------

-- | Squares an integer.
square :: Data Int -> Data Int
square x = x * x

-- | Sum of the squares of @1 .. n@.
sumSq :: Data Int -> Data Int
sumSq n = sum (map square (1 ... n))

-- | 1-D convolution. Each output element is the scalar product of the
-- kernel with a reversed prefix of the input. The prefixes come from
-- 'inits1' (presumably the non-empty prefixes; confirm).
conv1D :: DVector Float -> DVector Float -> DVector Float
conv1D kernel = map (scalarProd kernel . reverse) . inits1

-- | @modulus a b@ subtracts @b@ from @a@ as long as the value is still
-- @>= b@.
--
-- It is only meaningful for @a >= 0@ and @b > 0@. With @b == 0@ (and
-- @a >= 0@) the loop never terminates. With @b < 0@ the value grows
-- instead of shrinking.
modulus :: Data Int -> Data Int -> Data Int
modulus a b = while (>= b) (subtract b) a

-- | The array @[2^0, 2^1, .., 2^7]@.
powersOfTwo :: Data [Int]
powersOfTwo = parallel 8 (\i -> 2 ^ i)

--------------------------------------------------
-- Discrete cosine transform
--------------------------------------------------

-- | Type-II discrete cosine transform, computed as a dense
-- matrix-vector product with the n-by-n matrix built by 'dct2nkl'.
dct2 :: DVector Float -> DVector Float
dct2 xn = mat ** xn
  where
    n   = length xn
    mat = indexedMat n n (dct2nkl n)

-- | Entry @(k, l)@ of the n-point DCT-II matrix:
-- @cos (pi * k * (2l + 1) / (2n))@.
dct2nkl :: Data Int -> Data Int -> Data Int -> Data Float
dct2nkl n k l = cos ((k' * (2 * l' + 1) * piF) / (2 * n'))
  where
    (n', k', l') = (intToFloat n, intToFloat k, intToFloat l)
    -- Full-precision pi. A truncated 3.14 would skew every angle by about
    -- 0.05%.
    piF :: Data Float
    piF = 3.141592653589793

--------------------------------------------------
-- Sorter
--------------------------------------------------

-- | Minimum that appears as a call to @min@ in the generated code.
-- (Not used by 'comp' below, which uses Feldspar's own 'min'.)
minP :: (Storable a, Ord a) => Data a -> Data a -> Data a
minP = function2 "min" (\_ _ -> universal) Prelude.min

-- | Maximum that appears as a call to @max@ in the generated code.
-- (Not used by 'comp' below, which uses Feldspar's own 'max'.)
maxP :: (Storable a, Ord a) => Data a -> Data a -> Data a
maxP = function2 "max" (\_ _ -> universal) Prelude.max

-- | Orders a pair, returning @(smaller, larger)@.
comp :: (Storable a, Ord a) => (Data a, Data a) -> (Data a, Data a)
comp (a, b) = (min a b, max a b)

-- | Compare-and-swap on an array. The elements at indices @l@ and @r@ are
-- reordered so that the smaller one ends up at @l@ and the larger one
-- at @r@.
cswap :: Data [Int] -> (Data Int, Data Int) -> Data [Int]
cswap arr (l, r) = setIx (setIx arr r hi) l lo
  where
    (lo, hi) = comp (arr ! l, arr ! r)

-- | @ones k@ has its @k@ lowest bits set, i.e. @2^k - 1@.
ones :: Data Int -> Data Int
ones k = 2 ^ k - 1

-- | @onesZeros j k@ is @j@ ones shifted @k@ bits to the left.
onesZeros :: Data Int -> Data Int -> Data Int
onesZeros j k = shiftL (ones j) k

-- | @2^31 - 1@: the 31 lowest bits set.
allones :: Data Int
allones = 2 ^ 31 - 1

-- | Zeroes out the rightmost @i@ bits of @k@.
--
-- The mask is 'allones' shifted left by @i@. For @k@ in @[0, 2^31)@ this
-- clears exactly the @i@ lowest bits. The sorter only passes small
-- non-negative positions here. NOTE(review): for other @k@ the result
-- depends on the width of Int.
zeroBitsR :: Data Int -> Data Int -> Data Int
zeroBitsR i k = k .&. shiftL allones i

-- | Shifts bit @j@ and all higher bits of @k@ one step to the left and
-- leaves bit @j@ zero. In other words, it inserts a 0 bit at position @j@.
-- This works because @k + (k with its low j bits cleared)@ keeps the low
-- part of @k@ and doubles the high part.
setBitAndShift :: Data Int -> Data Int -> Data Int
setBitAndShift j k = k + zeroBitsR j k

-- | Compare-and-swap schedule for 'sort0'. It is a three-level nesting:
-- the outer index runs over @j in 0..n@, the middle one over @k in 0..j@,
-- and the innermost vector is @swapcol n (n - j) (j - k)@.
swapsT :: Data Int -> Vector (Vector (Vector (Data Int, Data Int)))
swapsT n = indexed (n + 1) (\j -> indexed (j + 1) (\k -> swapcol n (n - j) (j - k)))

-- | @2^n@ (index, partner) pairs.
--
-- Each index is the position with a 0 bit inserted at bit @i + v@ (see
-- 'setBitAndShift'). Its partner is that index with bits @i .. i+v@
-- flipped (XOR with @onesZeros (v + 1) i@).
swapcol :: Data Int -> Data Int -> Data Int -> Vector (Data Int, Data Int)
swapcol n i v = indexed (2 ^ n) (\k -> withPartner (setBitAndShift (i + v) k))
  where
    withPartner idx = (idx, xor (onesZeros (v + 1) i) idx)

-- | Sorting network: folds 'cswap' over every pair in @swapsT n@.
-- NOTE(review): the generated indices lie in @[0, 2^(n+1))@, so the input
-- array presumably has length @2^(n+1)@. Confirm with the callers.
sort0 :: Data Length -> Data [Int] -> Data [Int]
sort0 n as = fold (fold (fold cswap)) as (swapsT n)

--------------------------------------------------
-- Blake
--------------------------------------------------

type MessageBlock = DVector Unsigned32 -- 0..15
type Round        = Data Int
type State        = Matrix Unsigned32  -- 0..3 0..3

-- | The sixteen BLAKE-256 round constants.
co :: DVector Unsigned32
co = vector
    [ 0x243F6A88, 0x85A308D3, 0x13198A2E, 0x03707344
    , 0xA4093822, 0x299F31D0, 0x082EFA98, 0xEC4E6C89
    , 0x452821E6, 0x38D01377, 0xBE5466CF, 0x34E90C6C
    , 0xC0AC29B7, 0xC97C50DD, 0x3F84D5B5, 0xB5470917
    ]

-- | The ten BLAKE message-word permutations. Row @r@ is used in round @r@.
sigma :: Matrix Int
sigma = matrix
    [ [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]
    , [14, 10,  4,  8,  9, 15, 13,  6,  1, 12,  0,  2, 11,  7,  5,  3]
    , [11,  8, 12,  0,  5,  2, 15, 13, 10, 14,  3,  6,  7,  1,  9,  4]
    , [ 7,  9,  3,  1, 13, 12, 11, 14,  2,  6,  5, 10,  4,  0, 15,  8]
    , [ 9,  0,  5,  7,  2,  4, 10, 15, 14,  1, 11, 12,  6,  8,  3, 13]
    , [ 2, 12,  6, 10,  0, 11,  8,  3,  4, 13,  7,  5, 15, 14,  1,  9]
    , [12,  5,  1, 15, 14, 13,  4, 10,  0,  7,  6,  3,  9,  2,  8, 11]
    , [13, 11,  7, 14, 12,  1,  3,  9,  5,  0, 15,  4,  8,  6,  2, 10]
    , [ 6, 15, 14,  9, 11,  3,  0,  8, 12,  2, 13,  7,  1,  4, 10,  5]
    , [10,  2,  8,  4,  7,  6,  1,  5, 15, 11,  9, 14,  3, 12, 13,  0]
    ]

-- | All wrapped diagonals of a matrix. Row @i@ of the result is
-- @diag m i@, for @i@ ranging over the column indices of @m@.
diagonals :: Matrix a -> Matrix a
diagonals m = map (diag m) (0 ... (length (head m) - 1))

-- | The i'th wrapped diagonal. Element @r@ is row @r@ of @m@ indexed at
-- @(i + r) `mod` rowLength@.
diag :: Matrix a -> Data Int -> Vector (Data a)
diag m i = zipWith wrapIndex m (i ... (l + i))
  where
    l = length m - 1
    wrapIndex row j = row ! (j `mod` length row)

-- | Inverse of 'diagonals'. Row @i@ of the result is row @i@ of
-- @transpose m@ rotated right by @i@ positions.
invDiagonals :: Storable a => Matrix a -> Matrix a
invDiagonals m = zipWith shiftVectorR (0 ... (length m - 1)) (transpose m)

-- | Rotates a vector right by @i@ positions: the last @i@ elements move to
-- the front.
shiftVectorR :: Computable a => Data Int -> Vector a -> Vector a
shiftVectorR i v = reverse (drop i rev ++ take i rev)
  where
    rev = reverse v

-- | One BLAKE round on the 4x4 state.
--
-- The round first applies 'g' (with @i = 0..3@) to each column of the
-- state, then applies 'g' (with @i = 4..7@) to each wrapped diagonal.
-- NOTE(review): 'sigma' has only 10 rows, so @r@ must lie in @0..9@.
-- Presumably callers reduce the round number mod 10; confirm.
blakeRound :: MessageBlock -> State -> Round -> State
blakeRound m state r =
    ( invDiagonals
    . zipWith (g m r) (4 ... 7)
    . diagonals
    . transpose
    . zipWith (g m r) (0 ... 3)
    . transpose
    ) state

-- | The BLAKE @G_i@ function applied to the four words @[a, b, c, d]@ of
-- @v@, in round @r@ with message block @m@.
--
-- NOTE(review): the BLAKE specification uses right rotations (>>> 16, 12,
-- 8, 7) here. If Feldspar's @>>@ is a plain shift, this is not BLAKE;
-- confirm against Feldspar's Bits API.
g :: MessageBlock -> Round -> Data Int -> DVector Unsigned32 -> DVector Unsigned32
g m r i v = fromList [a'', b'', c'', d'']
  where
    [a, b, c, d] = toList 4 v
    a'  = a + b + (m!(sigma!r!(2*i)) ⊕ (co!(sigma!r!(2*i+1))))
    d'  = (d ⊕ a') >> 16
    c'  = c + d'
    b'  = (b ⊕ c') >> 12
    a'' = a' + b' + (m!(sigma!r!(2*i+1)) ⊕ (co!(sigma!r!(2*i))))
    d'' = (d' ⊕ a'') >> 8
    c'' = c' + d''
    b'' = (b' ⊕ c'') >> 7

--------------------------------------------------
-- Streams
--------------------------------------------------

-- | IIR filter with feedback coefficients @a@, feed-forward coefficients
-- @b@ and output scale @1 / a0@.
--
-- NOTE(review): the initial input buffer has length @q@ (= length a) and
-- the initial output buffer has length @p@ (= length b). The body,
-- however, reads @p@ past inputs and @q@ past outputs. Check this against
-- the buffer semantics of 'recurrenceI'.
iir :: Data Float -> DVector Float -> DVector Float
    -> Stream (Data Float) -> Stream (Data Float)
iir a0 a b input =
    recurrenceI (replicate q 0) input (replicate p 0)
      (\x y -> 1 / a0 * ( sum (indexed p (\i -> b!i * x!(p-i)))
                        - sum (indexed q (\j -> a!j * y!(q-j)))))
  where
    p = length b
    q = length a

--------------------------------------------------
-- Tests
--------------------------------------------------

test1 = eval (sumSq 10)
test2 = printCore sumSq
test3 = icompile' sumSq "sumSq" defaultOptions
test4 = printCore conv1D
test5 = eval (modulus 22 6)
test6 = eval powersOfTwo
test7 = printCore dct2
test8 = printCore sort0
test9 = printCore iirVec
  where
    -- Wrapper that makes the stream filter operate on vectors: cycle the
    -- input vector into a stream and take the first 100 outputs.
    iirVec a0 a b = S.take 100 . iir a0 a b . cycle