{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

-- | Discrete Fourier transforms for polynomial interpolation
module FFT
  ( CoeffVec,
    dftNaive,
    fft,
    fftMult,
    fftTargetPoly,
    interpolate,
    inverseDft,
  )
where

import Data.Field.Galois (GaloisField, pow)
import qualified Data.List as List
import Data.Poly (VPoly, monomial, toPoly)
import Data.Vector (fromList)
import Protolude

-- | Polynomial represented as a coefficient vector, little-endian: the
-- element at index @i@ is the coefficient of @x^i@ (so the head is the
-- constant term).
type CoeffVec f = [f]

-- | Result of a discrete Fourier transform: the values of some polynomial
-- evaluated at successive powers @omega^0, omega^1, ...@ of a root of
-- unity @omega@, in that order. (In our case the length of these lists
-- will be a power of two.)
type DFT f = [f]

-- | Evaluate a little-endian coefficient vector at the point @x@ using
-- Horner's scheme, i.e. @c0 + x * (c1 + x * (c2 + ...))@. The empty
-- vector evaluates to 0.
evalPoly :: Num f => CoeffVec f -> f -> f
evalPoly coeffs x = horner coeffs
  where
    horner [] = 0
    horner (c : cs) = c + x * horner cs

-- | Naive discrete Fourier transform: evaluate the polynomial separately at
-- each of @omega_n^0, omega_n^1, ..., omega_n^(len - 1)@, where @len@ is
-- the number of coefficients. Quadratic in the input length.
dftNaive ::
  Num f =>
  -- | principal 2^k-th root of unity
  f ->
  -- | polynomial coefficients, length should be 2^k for
  -- some k
  CoeffVec f ->
  DFT f
dftNaive omega_n as = [evalPoly as point | point <- points]
  where
    points = map (omega_n ^) [0 .. length as - 1]

-- | Split a list by index parity. The first component holds the elements
-- at indices 0, 2, 4, ... and the second holds those at indices 1, 3, 5, ...
--
-- >>> split [1, 2, 3, 4, 5]
-- ([1,3,5],[2,4])
split :: [a] -> ([a], [a])
split [] = ([], [])
split (x : rest) =
  case split rest of
    -- x sits at index 0, so rest's odd-index elements become our
    -- even-index successors and vice versa.
    (restEvens, restOdds) -> (x : restOdds, restEvens)

-- | Ceiling of the base-2 logarithm of a positive 'Int', computed from the
-- bit pattern. The index of the highest set bit gives the floor. One is
-- added unless no lower bits are set, i.e. unless the input is already a
-- power of two. For input 0 the result is -1.
log2 :: Int -> Int
log2 x
  | exact = highBit
  | otherwise = highBit + 1
  where
    -- index of the most significant set bit (floor of log2 for x > 0)
    highBit = finiteBitSize x - 1 - countLeadingZeros x
    -- nothing set below the top bit
    exact = countTrailingZeros x >= highBit

-- | Fast Fourier transformation (radix-2 Cooley-Tukey).
--
-- For an input of length @n@ this produces the values of the polynomial at
-- @omega^0, omega^1, ..., omega^(n-1)@ with @omega = omega_n (log2 n)@.
-- It assumes the roots are mutually consistent
-- (@omega_n (m - 1) == omega_n m ^ 2@). Under that assumption the result
-- agrees with @dftNaive (omega_n (log2 n))@.
--
-- The input length must be a power of two; other lengths yield
-- meaningless output because the butterfly step truncates. The empty list
-- maps to the empty list.
fft ::
  GaloisField k =>
  -- | function that gives for input n the principal (2^n)-th root of unity
  (Int -> k) ->
  -- | coefficients; length should be a power of two
  CoeffVec k ->
  DFT k
fft omega_n as =
  case length as of
    -- Without this case an empty input would split into two empty halves
    -- and recurse forever.
    0 -> []
    1 -> as
    n ->
      let (evens, odds) = split as
          y0 = fft omega_n evens
          y1 = fft omega_n odds
          half = n `div` 2
          omega = omega_n (log2 n)
          -- 'combine' consumes exactly one twiddle factor per pair of
          -- half-size outputs, so only omega^0 .. omega^(half-1) are needed.
          twiddles = map (pow omega) [0 .. half - 1]
       in combine y0 y1 twiddles
  where
    -- Output k is y0_k + w^k * y1_k; output k + n/2 is y0_k - w^k * y1_k.
    combine y0 y1 twiddles =
      let (lows, highs) = List.unzip (List.zipWith3 butterfly y0 y1 twiddles)
       in lows ++ highs
    -- One butterfly; the product w * o is shared by both outputs.
    butterfly e o w =
      let t = w * o
       in (e + t, e - t)

-- | Inverse discrete Fourier transformation, computed with 'fft'.
--
-- Runs 'fft' with the reciprocal of each supplied root of unity, then
-- divides every output by @n@, the length of the input.
inverseDft :: GaloisField k => (Int -> k) -> DFT k -> CoeffVec k
inverseDft primRootsUnity dft = map (/ size) (fft inverseRoot dft)
  where
    size = fromIntegral (length dft)
    inverseRoot = recip . primRootsUnity

-- | Pad a list with the fewest trailing zeroes that give it a power-of-two
-- length. The empty list is returned unchanged.
padToNearestPowerOfTwo :: Num f => [f] -> [f]
padToNearestPowerOfTwo xs
  | null xs = []
  | otherwise = padToNearestPowerOfTwoOf (length xs) xs

-- | Append zeroes to a list until its length reaches the smallest power of
-- two that is at least @i@ (and at least 1).
--
-- The target length is @2^(ceiling (log2 i))@, /not/ @2^i@. A list that is
-- already longer than the target is returned unchanged.
padToNearestPowerOfTwoOf ::
  Num f =>
  -- | lower bound @i@ on the target length; values <= 0 give a target of 1
  Int ->
  -- | list to pad; should have length <= the target length
  [f] ->
  -- | the list followed by enough zeroes to reach the target length
  [f]
padToNearestPowerOfTwoOf i xs = xs ++ replicate (target - length xs) 0
  where
    -- Smallest power of two >= i, found by doubling from 1. Starting at 1
    -- keeps i <= 0 well-defined. With @bit (log2 0)@ the computation would
    -- evaluate @bit (-1)@ and throw, e.g. for @fftMult [] []@.
    target = grow 1
    grow p
      | p >= i = p
      | otherwise = grow (2 * p)

-- | Create a polynomial from its values at successive powers of a root of
-- unity.
--
-- The values are first zero-padded to a power-of-two length @m@ and then
-- run through 'inverseDft'. Assuming @primRoots@ yields consistent
-- principal roots, the resulting polynomial takes the i-th value at
-- @omega^i@, where @omega = primRoots (log2 m)@.
interpolate :: GaloisField k => (Int -> k) -> [k] -> VPoly k
interpolate primRoots pts =
  let coeffs = inverseDft primRoots (padToNearestPowerOfTwo pts)
   in toPoly (fromList coeffs)

-- | Multiply two polynomials by pointwise multiplication of their DFTs.
--
-- Both inputs are zero-padded to the smallest power of two that is at
-- least twice the length of the longer input. That is enough room for the
-- product's coefficients, so the result has this padded length.
fftMult :: GaloisField k => (Int -> k) -> CoeffVec k -> CoeffVec k -> CoeffVec k
fftMult primRoots l r =
  let size = 2 * max (length l) (length r)
      transform = fft primRoots . padToNearestPowerOfTwoOf size
   in inverseDft primRoots (zipWith (*) (transform l) (transform r))

-- XXX make this actually go fast
-- polyWithZeroesAt
--    :: Fractional f
--    => (Int -> f)
--    -> [f]
--    -> CoeffVec f
-- polyWithZeroesAt primRoots
--   = foldl' (fftMult primRoots) [1]
--     . map (\xcoord -> [-xcoord, 1])

-- | Vanishing polynomial of the 2^k-th roots of unity, where
-- @k = log2 numRoots@ (ceiling). This is @prod (x - omega^i)@ for
-- @i = 0 .. 2^k - 1@ and @omega = primRoots k@.
--
-- When @omega@ is a principal 2^k-th root of unity, its powers are exactly
-- the 2^k distinct roots of @x^(2^k) - 1@. The product therefore equals
-- that binomial, which is built directly instead of through 2^k successive
-- polynomial multiplications. The root function is not consulted.
--
-- @numRoots@ must be positive: for 0, @log2@ returns -1 and @2 ^ k@
-- fails with a negative exponent.
fftTargetPoly :: GaloisField k => (Int -> k) -> Int -> VPoly k
fftTargetPoly _primRoots numRoots = monomial (2 ^ k) 1 - monomial 0 1
  where
    k = log2 numRoots