{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Poly.Internal.Dense.DFT
( dft
, inverseDft
) where
import Prelude hiding (recip, fromIntegral)
import Control.Monad.ST
import Data.Bits hiding (shift)
import Data.Foldable
import Data.Semiring (Semiring(..), Ring(..), minus, fromIntegral)
import Data.Field (Field, recip)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
-- | <https://en.wikipedia.org/wiki/Fast_Fourier_transform Discrete Fourier transform>
-- \( y_k = \sum_{j=0}^{N-1} x_j \omega^{jk} \), evaluated with a recursive
-- radix-2 Cooley–Tukey split into even- and odd-indexed subsequences.
--
-- The length \( N \) must be a power of two; otherwise (including for an
-- empty vector) this calls 'error'.
--
-- The butterfly step relies on \( \omega^{N/2} = -1 \), so the output is the
-- sum above when \( \omega \) is a primitive \( N \)-th root of unity.
dft
  :: (Ring a, G.Vector v a)
  => a   -- ^ \( \omega \), expected to be a primitive \( N \)-th root of unity
  -> v a -- ^ \( \{ x_j \}_{j=0}^{N-1} \) with \( N = 2^n \)
  -> v a -- ^ \( \{ y_k \}_{k=0}^{N-1} \)
dft primRoot (xs :: v a)
  | popCount size /= 1 = error "dft: only vectors of length 2^n are supported"
  | otherwise          = transform 0 0
  where
    size = G.length xs
    -- Exponent of two: size == 2 ^ logSize once the guard has passed.
    logSize = countTrailingZeros size

    -- Twiddle factors primRoot^0 .. primRoot^(size/2 - 1), each power forced
    -- before the next multiplication. It is only indexed inside the k >= 1
    -- loop below, so it is never demanded when size == 1 (where the length
    -- expression would shift by -1).
    twiddles :: v a
    twiddles = G.iterateN (1 `unsafeShiftL` (logSize - 1)) nextPower one
      where
        nextPower p = p `seq` (p `times` primRoot)

    -- 'transform start level' is the DFT of the strided subsequence
    -- xs[start + j * 2^level] for j in [0, 2^(logSize - level)).
    transform !start !level
      | level >= logSize = G.unsafeSlice start 1 xs
      | otherwise        = runST $ do
          let half  = 1 `unsafeShiftL` (logSize - level - 1)
              evens = transform start (level + 1)
              odds  = transform (start + (1 `unsafeShiftL` level)) (level + 1)
          out <- MG.new (half `unsafeShiftL` 1)
          -- Store e + o at index k and e - o at index k + half, each forced.
          let butterfly k e o = do
                MG.unsafeWrite out k $! e `plus` o
                MG.unsafeWrite out (k + half) $! e `minus` o
          -- k == 0: the twiddle factor would be 'one', so skip the product.
          butterfly 0 (G.unsafeIndex evens 0) (G.unsafeIndex odds 0)
          let loop k
                | k >= half = pure ()
                | otherwise = do
                    let twiddle = G.unsafeIndex twiddles (k `unsafeShiftL` level)
                    butterfly k (G.unsafeIndex evens k)
                                (G.unsafeIndex odds k `times` twiddle)
                    loop (k + 1)
          loop 1
          G.unsafeFreeze out
-- | Inverse of 'dft'. Given the root \( \omega \) that produced
-- \( \{ y_k \} \), computes
-- \( x_j = N^{-1} \sum_{k=0}^{N-1} y_k \omega^{-jk} \)
-- by running 'dft' with \( \omega^{-1} \) and then scaling every entry by
-- \( N^{-1} \).
--
-- It inherits the power-of-two length requirement from 'dft'. The factor
-- \( N^{-1} \) is taken with 'recip', so \( N \) is assumed to be invertible
-- in the field (e.g. not a multiple of its characteristic).
inverseDft
  :: (Field a, G.Vector v a)
  => a   -- ^ \( \omega \), the root of unity used by the forward 'dft'
  -> v a -- ^ \( \{ y_k \}_{k=0}^{N-1} \)
  -> v a -- ^ \( \{ x_j \}_{j=0}^{N-1} \)
inverseDft primRoot ys = G.map scaleByInvLen unscaled
  where
    unscaled = dft (recip primRoot) ys
    -- N^{-1}, computed once and shared by every entry.
    invLen = recip (fromIntegral (G.length ys))
    scaleByInvLen y = y `times` invLen