{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module FFT
( CoeffVec,
dftNaive,
fft,
fftMult,
fftTargetPoly,
interpolate,
inverseDft,
)
where
import Data.Field.Galois (GaloisField, pow)
import qualified Data.List as List
import Data.Poly (VPoly, monomial, toPoly)
import Data.Vector (fromList)
import Protolude
-- | Polynomial coefficients, lowest-degree coefficient first: 'evalPoly'
-- treats the head of the list as the constant term.
type CoeffVec f = [f]
-- | Evaluations of a polynomial at successive powers @omega^0, omega^1, ...@
-- of the transform's root, as produced by 'dftNaive' and 'fft'.
type DFT f = [f]
-- | Evaluate a polynomial at a point using Horner's rule.
--
-- Coefficients are lowest degree first, so @evalPoly [c0, c1, c2] x@ is
-- @c0 + x * (c1 + x * (c2 + x * 0))@. The empty polynomial evaluates to @0@.
evalPoly :: Num f => CoeffVec f -> f -> f
evalPoly [] _ = 0
evalPoly (c : cs) x = c + x * evalPoly cs x
-- | Reference quadratic-time discrete Fourier transform.
--
-- Evaluates the polynomial @as@ (lowest-degree coefficient first) at
-- @omega_n ^ i@ for every @i@ in @[0 .. length as - 1]@, one Horner
-- evaluation per point. Useful as an oracle for testing 'fft'.
dftNaive ::
  Num f =>
  f ->
  CoeffVec f ->
  DFT f
dftNaive omega_n as =
  [evalPoly as (omega_n ^ i) | i <- [0 .. length as - 1]]
-- | De-interleave a list into its even-indexed and odd-indexed elements.
--
-- >>> split "abcde"
-- ("ace","bd")
split :: [a] -> ([a], [a])
split [] = ([], [])
split (x : rest) =
  -- The even-indexed elements of @rest@ sit at odd indices of the whole
  -- list (and vice versa), hence the swap.
  case split rest of
    (restEvens, restOdds) -> (x : restOdds, restEvens)
-- | Ceiling of the base-2 logarithm: for positive @x@, the smallest @e@
-- with @2^e >= x@. Note that @log2 0 == -1@.
log2 :: Int -> Int
log2 x
  | exactPower = floorLog
  | otherwise = floorLog + 1
  where
    -- Index of the highest set bit, i.e. floor of the base-2 logarithm.
    floorLog = finiteBitSize x - 1 - countLeadingZeros x
    -- For positive x: no bit below the top one is set, so x == 2^floorLog.
    exactPower = countTrailingZeros x >= floorLog
-- | Radix-2 Cooley–Tukey fast Fourier transform over a Galois field.
--
-- @omega_n e@ is assumed (not checked here) to return a primitive
-- @2^e@-th root of unity. It is called with @'log2' n@ for every
-- sub-problem of length @n >= 2@.
--
-- The input length should be a power of two. Other lengths are not
-- rejected, but 'List.zip3' in @combine@ truncates to the shorter half, so
-- the output is shorter than the input (e.g. 3 inputs give 2 outputs).
-- An empty input yields an empty output.
fft ::
  GaloisField k =>
  (Int -> k) ->
  CoeffVec k ->
  DFT k
fft omega_n as =
  case length as of
    -- 'split' maps [] to ([], []), so recursing on an empty list would
    -- never terminate; the DFT of an empty vector is empty.
    0 -> []
    1 -> as
    n ->
      let (as0, as1) = split as
          y0 = fft omega_n as0
          y1 = fft omega_n as1
          -- Only the first n/2 twiddle factors are consumed by 'combine'.
          omegas = map (pow (omega_n (log2 n))) [0 .. n `div` 2 - 1]
       in combine y0 y1 omegas
  where
    -- Butterfly: y[k] = y0[k] + w^k * y1[k] and y[k + n/2] = y0[k] - w^k * y1[k];
    -- all "+" results come first, followed by all "-" results.
    combine y0 y1 omegas =
      (\xs -> map fst xs ++ map snd xs)
        $ map butterfly
        $ List.zip3 y0 y1 omegas
    butterfly (yk0, yk1, currentOmega) =
      let twiddled = currentOmega * yk1
       in (yk0 + twiddled, yk0 - twiddled)
-- | Inverse of 'fft'.
--
-- Runs the forward transform with every root from @primRootsUnity@
-- replaced by its reciprocal, then divides each entry by the number of
-- input values.
inverseDft :: GaloisField k => (Int -> k) -> DFT k -> CoeffVec k
inverseDft primRootsUnity dft = map (/ count) transformed
  where
    transformed = fft (recip . primRootsUnity) dft
    count = fromIntegral (length dft)
-- | Right-pad a list with zeros to the smallest power-of-two length that is
-- at least its current length. The empty list is returned as-is.
padToNearestPowerOfTwo :: Num f => [f] -> [f]
padToNearestPowerOfTwo xs = case xs of
  [] -> []
  _ : _ -> padToNearestPowerOfTwoOf (length xs) xs
-- | Right-pad @xs@ with zeros up to length @2^('log2' i)@, the smallest
-- power of two that is at least @i@.
--
-- If @xs@ is already at least that long it is returned unchanged (it is
-- never truncated). A non-positive target @i@ also returns @xs@ unchanged.
-- For @i == 0@ this avoids passing @'log2' 0 == -1@ to 'bit', i.e. a
-- negative shift, which 'Int' rejects.
padToNearestPowerOfTwoOf ::
  Num f =>
  Int ->
  [f] ->
  [f]
padToNearestPowerOfTwoOf i xs
  | i <= 0 = xs
  | otherwise = xs ++ replicate padLength 0
  where
    -- 'replicate' treats a negative count as 0, so longer inputs pass through.
    padLength = nearestPowerOfTwo - length xs
    nearestPowerOfTwo = bit $ log2 i
-- | Recover a polynomial from its values.
--
-- Zero-pads @pts@ to a power-of-two length and applies 'inverseDft'.
-- Presumably @pts@ are evaluations at successive powers of the root
-- supplied by @primRoots@, in the order 'fft' produces them; confirm with
-- callers.
interpolate :: GaloisField k => (Int -> k) -> [k] -> VPoly k
interpolate primRoots pts = toPoly (fromList coefficients)
  where
    coefficients = inverseDft primRoots (padToNearestPowerOfTwo pts)
-- | Multiply two polynomials (lowest-degree coefficient first) via FFT.
--
-- Both operands are zero-padded to the smallest power of two that is at
-- least twice the longer operand. They are then transformed with 'fft',
-- multiplied pointwise and transformed back with 'inverseDft'. The result
-- therefore has that padded length and may carry trailing zero
-- coefficients.
--
-- @primRoots@ is assumed to behave as required by 'fft' (a primitive
-- @2^e@-th root of unity for each @e@); it is not checked here.
--
-- Two empty operands yield @[]@: the padded length would be 0, and there is
-- nothing to transform.
fftMult :: GaloisField k => (Int -> k) -> CoeffVec k -> CoeffVec k -> CoeffVec k
fftMult _ [] [] = []
fftMult primRoots l r = inverseDft primRoots $ zipWith (*) dftL dftR
  where
    -- The product has length l + length r - 1 <= n - 1 coefficients, so the
    -- cyclic convolution computed by the DFT does not wrap around.
    n = 2 * max (length l) (length r)
    paddedDft x = fft primRoots (padToNearestPowerOfTwoOf n x)
    dftL = paddedDft l
    dftR = paddedDft r
-- | The product of the linear factors @(x - omega^i)@ for @i@ in
-- @[0 .. 2^e - 1]@.
--
-- Here @e = 'log2' numRoots@ (numRoots rounded up to a power of two) and
-- @omega = primRoots e@. If @omega@ is a primitive @2^e@-th root of unity
-- this equals @x^(2^e) - 1@, but the product is always built explicitly.
fftTargetPoly :: GaloisField k => (Int -> k) -> Int -> VPoly k
fftTargetPoly primRoots numRoots = foldl' (*) (monomial 0 1) linearFactors
  where
    k = log2 numRoots
    omega = primRoots k
    -- (x - omega^i) as a lowest-degree-first coefficient list.
    linearFactor i = toPoly . fromList $ [- pow omega i, 1]
    linearFactors = map linearFactor [0 .. 2 ^ k - 1 :: Integer]