{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

-- | Discrete Fourier transforms for polynomial interpolation
module FFT
  ( CoeffVec,
    dftNaive,
    fft,
    fftMult,
    fftTargetPoly,
    interpolate,
    inverseDft,
  )
where

import Data.Field.Galois (GaloisField, pow)
import qualified Data.List as List
import Data.Poly (VPoly, monomial, toPoly)
import Data.Vector (fromList)
import Protolude

-- | Polynomial represented as a coefficient vector, little-endian
-- (the coefficient of @x^0@ comes first).
type CoeffVec f = [f]

-- | Discrete Fourier transform. Can be interpreted as some polynomial
-- evaluated at certain roots of unity. (In our case the length of
-- these lists will be a power of two.)
type DFT f = [f]

-- | Evaluate a polynomial given by its coefficient vector at a point,
-- using Horner's scheme. The empty coefficient vector evaluates to 0.
evalPoly :: Num f => CoeffVec f -> f -> f
evalPoly coeffs x = foldr (\c rest -> c + x * rest) 0 coeffs

-- | Naive discrete Fourier transformation performed by evaluating the
-- polynomial at the appropriate roots of unity: the i-th output is the
-- polynomial evaluated at @omega_n ^ i@, for @i@ from 0 up to
-- @length as - 1@.
dftNaive ::
  Num f =>
  -- | principal 2^k-th root of unity
  f ->
  -- | polynomial coefficients, length should be 2^k for
  -- some k
  CoeffVec f ->
  DFT f
dftNaive omega_n as =
  map (\i -> evalPoly as (omega_n ^ i)) [0 .. length as - 1]

-- | Split a list into the elements at even positions (0, 2, 4, ...)
-- and the elements at odd positions (1, 3, 5, ...), in that order.
split :: [a] -> ([a], [a])
split = foldr (\a (r1, r2) -> (a : r2, r1)) ([], [])

-- | Calculate ceiling of log base 2 of an integer.
--
-- Inputs less than or equal to 1 yield 0. The logarithm is not defined
-- for non-positive numbers; clamping keeps the sizes derived from this
-- function (see 'padToNearestPowerOfTwoOf' and 'fftTargetPoly') from
-- ever using a negative exponent.
log2 :: Int -> Int
log2 x
  | x <= 1 = 0
  | otherwise = floorLog + correction
  where
    -- Index of the highest set bit, i.e. floor (log2 x) for x > 0.
    floorLog = finiteBitSize x - 1 - countLeadingZeros x
    -- x is an exact power of two iff its lowest set bit is also its
    -- highest set bit; otherwise round up.
    correction = if countTrailingZeros x < floorLog then 1 else 0

-- | Fast Fourier transformation.
--
-- The empty list is mapped to the empty list and a singleton list is
-- returned unchanged. The input length is expected to be a power of
-- two; other lengths are not rejected, but the combine step then
-- truncates to the shorter of the two recursive results, so the output
-- is not a meaningful transform.
fft ::
  GaloisField k =>
  -- | function that gives for input n the principal (2^n)-th root of unity
  (Int -> k) ->
  -- | length should be n
  CoeffVec k ->
  DFT k
fft _ [] = []
fft _ [a] = [a]
fft omega_n as =
  zipWith (+) evens twiddled ++ zipWith (-) evens twiddled
  where
    (as0, as1) = split as
    evens = fft omega_n as0
    odds = fft omega_n as1
    omega = omega_n (log2 (length as))
    -- Successive powers omega^0, omega^1, ... are generated lazily;
    -- only as many as there are elements in 'odds' are ever computed.
    twiddled = zipWith (*) (List.iterate (* omega) 1) odds

-- | Inverse discrete Fourier transformation, uses FFT.
--
-- Runs 'fft' with the reciprocal of each root of unity and divides
-- every output by the length of the input.
inverseDft :: GaloisField k => (Int -> k) -> DFT k -> CoeffVec k
inverseDft primRootsUnity dft =
  let n = fromIntegral . length $ dft
   in map (/ n) $ fft (recip . primRootsUnity) dft

-- | Append minimal amount of zeroes until the list has a length which
-- is a power of two. The empty list is returned unchanged.
padToNearestPowerOfTwo :: Num f => [f] -> [f]
padToNearestPowerOfTwo [] = []
padToNearestPowerOfTwo xs = padToNearestPowerOfTwoOf (length xs) xs

-- | Given n, append zeroes until the list has as many elements as the
-- smallest power of two that is greater than or equal to n.
--
-- If the list is already longer than that, it is returned unchanged.
padToNearestPowerOfTwoOf ::
  Num f =>
  -- | n
  Int ->
  -- | list which should have length <= the smallest power of two >= n
  [f] ->
  -- | list padded with zeroes up to that power of two
  [f]
padToNearestPowerOfTwoOf i xs = xs ++ replicate padLength 0
  where
    padLength = nearestPowerOfTwo - length xs
    nearestPowerOfTwo = bit $ log2 i

-- | Create a polynomial that goes through the given values.
--
-- The values are zero-padded to a power-of-two length and passed
-- through 'inverseDft'.
interpolate :: GaloisField k => (Int -> k) -> [k] -> VPoly k
interpolate primRoots pts =
  toPoly . fromList $
    inverseDft primRoots (padToNearestPowerOfTwo pts)

-- | Multiply polynomials using FFT.
--
-- Both inputs are zero-padded to the smallest power of two that is at
-- least twice the length of the longer input, transformed, multiplied
-- pointwise and transformed back. The result has that padded length,
-- so it may carry trailing zero coefficients.
fftMult :: GaloisField k => (Int -> k) -> CoeffVec k -> CoeffVec k -> CoeffVec k
fftMult primRoots l r =
  inverseDft primRoots $ zipWith (*) dftL dftR
  where
    n = 2 * max (length l) (length r)
    paddedDft x = fft primRoots (padToNearestPowerOfTwoOf n x)
    dftL = paddedDft l
    dftR = paddedDft r

-- XXX make this actually go fast
-- polyWithZeroesAt
--   :: Fractional f
--   => (Int -> f)
--   -> [f]
--   -> CoeffVec f
-- polyWithZeroesAt primRoots
--   = foldl' (fftMult primRoots) [1]
--   . map (\xcoord -> [-xcoord, 1])

-- XXX make this actually use FFT mult

-- | Polynomial whose roots are the powers @omega^0 .. omega^(2^k - 1)@,
-- where @k@ is the ceiling of log base 2 of the requested number of
-- roots and @omega@ is @primRoots k@. Built as the product of the
-- linear factors @(x - omega^i)@.
--
-- NOTE(review): if @omega@ really is a primitive (2^k)-th root of
-- unity this product should equal @x^(2^k) - 1@, which could be
-- constructed directly -- confirm before relying on it.
fftTargetPoly :: GaloisField k => (Int -> k) -> Int -> VPoly k
fftTargetPoly primRoots numRoots =
  foldl'
    (*)
    (monomial 0 1)
    ( (\i -> toPoly . fromList $ [negate (pow omega i), 1])
        <$> [0 .. 2 ^ k - 1 :: Integer]
    )
  where
    k = log2 numRoots
    omega = primRoots k