module Numeric.FFT.Utils
( omega, slicevecs, slicemvecs, primes, isPrime
, allFactors, factors
, primitiveRoot, invModN, log2, isPow2, dupperm, (%.%)
, compositions, makeComp, multisetPerms
, backpermuteM
) where
import Prelude hiding (all, concatMap, dropWhile, enumFromTo,
filter, head, length, map, maximum, null, reverse)
import qualified Prelude as P
import qualified Control.Monad as CM
import Control.Monad.ST
import Data.Bits
import Data.Complex
import Data.Vector.Unboxed
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed.Mutable as MV
import Data.List (nub)
import qualified Data.List as L
import Numeric.FFT.Types
-- | The complex number @cis (2*pi/n)@, i.e. @e^(2*pi*i/n)@: an @n@-th root of
-- unity on the unit circle.
omega :: Int -> Complex Double
omega n = cis theta
  where
    theta = 2 * pi / fromIntegral n
-- | Split @v@ into consecutive, non-overlapping slices of length @m@,
-- starting at offsets @0, m, 2m, ...@.  Trailing elements that do not fill a
-- whole slice are dropped.  @m@ must be non-zero (it is used as a divisor).
slicevecs :: Int -> VCD -> VVCD
slicevecs m v = V.map (\off -> slice off m v) offsets
  where
    offsets = V.enumFromStepN 0 m (length v `div` m)
-- | Mutable counterpart of 'slicevecs': split @v@ into consecutive,
-- non-overlapping mutable slices of length @m@ at offsets @0, m, 2m, ...@.
-- Trailing elements that do not fill a whole slice are dropped.  @m@ must be
-- non-zero.
slicemvecs :: Int -> MVCD a -> VMVCD a
slicemvecs m v = V.map (\off -> MV.slice off m v) offsets
  where
    offsets = V.enumFromStepN 0 m (MV.length v `div` m)
-- | The smallest primitive root modulo the prime @p@: the least @g@ in
-- @[1, p-1]@ such that @g^((p-1)/q) mod p /= 1@ for every distinct prime
-- factor @q@ of @p-1@.
--
-- Calls 'error' if @p@ is not prime.
primitiveRoot :: Int -> Int
primitiveRoot p
  | not (isPrime p) = error "Attempt to take primitive root of non-prime value"
  -- p - 1 == 1 has no prime factors, so there is nothing to test; handling
  -- it here also avoids factorising 1.
  | p == 2 = 1
  | otherwise =
      case find check (enumFromTo 1 (p - 1)) of
        Just g -> g
        Nothing ->
          error "Numeric.FFT.Utils.primitiveRoot: no primitive root found"
  where
    tot = p - 1
    -- Exponents (p-1)/q for each distinct prime factor q of p-1.
    totpows = map (tot `div`) $ fromList $ nub $ toList $ allFactors tot
    check g = all (/= 1) $ map (expt p g) totpows
-- | Modular exponentiation: @expt n b pow@ computes @b^pow mod n@ by
-- repeated squaring.  Intermediate products are kept as 'Integer', so they
-- cannot overflow 'Int'.  For positive @n@ the result lies in @[0, n)@.
--
-- Calls 'error' for a negative exponent.
expt :: Int -> Int -> Int -> Int
expt n b pow
  | pow < 0 = error "Numeric.FFT.Utils.expt: negative exponent"
  | otherwise = fromIntegral $ go pow
  where
    bb = fromIntegral b
    nb = fromIntegral n
    go :: Int -> Integer
    go p
      -- Reduce 1 as well, so that modulus 1 yields 0 rather than 1.
      | p == 0 = 1 `mod` nb
      | odd p = (bb * go (p - 1)) `mod` nb
      | otherwise = let h = go (p `div` 2) in (h * h) `mod` nb
-- | Multiplicative inverse of @g@ modulo @n@: the unique @iv@ in @[1, n-1]@
-- with @(g * iv) `mod` n == 1@.
--
-- Uses the extended Euclidean algorithm, which needs O(log n) steps and
-- never forms the product @g * iv@, so it cannot overflow.  Calls 'error'
-- when no inverse exists, i.e. when @n <= 1@ or @gcd g n /= 1@.
invModN :: Int -> Int -> Int
invModN n g
  | n > 1, r == 1 = s `mod` n
  | otherwise = error "Numeric.FFT.Utils.invModN: value has no inverse modulo n"
  where
    -- r = gcd (g mod n) n, and s satisfies s * g == r (mod n).
    (r, s) = egcd (g `mod` n) n
    egcd a b = go a b 1 0
    -- Invariant: r0 == s0 * a and r1 == s1 * a (mod b).
    go r0 r1 s0 s1
      | r1 == 0 = (r0, s0)
      | otherwise =
          let q = r0 `div` r1
          in go r1 (r0 - q * r1) s1 (s0 - q * s1)
-- | The infinite ascending list of primes.
--
-- Odd candidates are emitted directly until they reach the square of the
-- next known prime.  At that point the candidates are filtered by that prime
-- and the bound moves on to the following prime's square.  The known primes
-- come from the output list itself, which is consumed lazily.
primes :: Integral a => [a]
primes = 2 : oddPrimes
  where
    oddPrimes = sieve [3, 5 ..] 9 oddPrimes
    sieve (c : cs) bound known
      | c < bound = c : sieve cs bound known
      | otherwise =
          -- let-patterns are lazy, so 'known' is only inspected here, after
          -- its head has already been produced.
          let (q : later) = known
          in sieve (P.filter (\m -> rem m q /= 0) cs) (P.head later ^ 2) later
-- | Primality test by trial division.
--
-- Only divisors up to @sqrt n@ are tried, namely 2 followed by the odd
-- numbers.  The bound is written as @d <= n `quot` d@ so that @d * d@ cannot
-- overflow in bounded types.  Values below 2 are not prime.
isPrime :: Integral a => a -> Bool
isPrime n = n >= 2 && P.all (\d -> n `rem` d /= 0) candidates
  where
    candidates = P.takeWhile (\d -> d <= n `quot` d) (2 : [3, 5 ..])
-- | Prime factorisation of @n@, in ascending order and with multiplicity,
-- e.g. @allFactors 360 == fromList [2,2,2,3,3,5]@.  @allFactors 1@ is empty.
--
-- Trial division by 2 and the odd numbers.  Composite divisors never divide
-- the remaining cofactor, because their prime factors have already been
-- removed.  Once @d > sqrt cur@, the cofactor @cur@ is itself prime.
--
-- Calls 'error' for @n < 1@.
allFactors :: (Integral a, Unbox a) => a -> Vector a
allFactors n
  | n < 1 = error "Numeric.FFT.Utils.allFactors: argument must be positive"
  | n == 1 = empty
  | otherwise = fromList $ go n (2 : [3, 5 ..])
  where
    -- Invariant: cur >= 2 here.
    go cur ds@(d : rest)
      | d > cur `quot` d = [cur]
      | cur `rem` d == 0 = d : go (cur `quot` d) ds
      | otherwise = go cur rest
    go cur [] = [cur]  -- unreachable: the divisor list is infinite
-- | Prime factorisation of @n@, split into the largest prime factor and the
-- remaining factors in ascending order (with multiplicity), e.g.
-- @factors 12 == (3, fromList [2,2])@.  @factors 1 == (1, empty)@.
--
-- Uses the same trial division as 'allFactors', stopping at @sqrt cur@.
-- Calls 'error' for @n < 1@.
factors :: (Integral a, Unbox a) => a -> (a, Vector a)
factors n
  | n < 1 = error "Numeric.FFT.Utils.factors: argument must be positive"
  | otherwise = let (lst, rest) = go n (2 : [3, 5 ..]) in (lst, fromList rest)
  where
    go cur ds@(d : more)
      -- No divisor up to sqrt cur is left, so cur is the last (largest)
      -- factor; for cur == 1 this yields (1, []).
      | d > cur `quot` d = (cur, [])
      | cur `rem` d == 0 =
          let (lst, rest) = go (cur `quot` d) ds
          in (lst, d : rest)
      | otherwise = go cur more
    go cur [] = (cur, [])  -- unreachable: the divisor list is infinite
-- | Base-2 logarithm rounded down: the largest @k@ with @2^k <= n@.
--
-- Calls 'error' for @n < 1@, which has no such @k@.
log2 :: Int -> Int
log2 n
  | n < 1 = error "Numeric.FFT.Utils.log2: argument must be positive"
  | n == 1 = 0
  | otherwise = 1 + log2 (n `div` 2)
-- | True exactly when @n@ is a positive power of two (including @2^0 = 1@).
--
-- A positive power of two has a single set bit, so clearing its lowest set
-- bit (@n .&. (n - 1)@) leaves zero.  Zero and negative numbers give False.
isPow2 :: Int -> Bool
isPow2 n = n > 0 && n .&. (n - 1) == 0
-- | Repeat the index permutation @p@ over consecutive blocks.  There are
-- @n `div` length p@ blocks, and block @d@ is @p@ with every entry shifted by
-- @d * length p@.  The results are concatenated in block order.
dupperm :: Int -> VI -> VI
dupperm n p = concatMap (\off -> map (+ off) p) offsets
  where
    blockLen = length p
    -- 0, blockLen, 2 * blockLen, ... (one offset per block)
    offsets = enumFromStepN 0 blockLen (n `div` blockLen)
-- | Compose two index permutations: @(p1 %.% p2) ! i == p2 ! (p1 ! i)@,
-- i.e. 'backpermute' of @p2@ by @p1@.
(%.%) :: VI -> VI -> VI
(%.%) = flip backpermute
-- | All ways to group the ascending prime factors of @n@ into runs of
-- adjacent factors, replacing each run by its product.
--
-- With @k@ prime factors (counted with multiplicity) there are exactly
-- @2^(k-1)@ compositions, one for each valid 'makeComp' index.  The count
-- must use @k@, not @n@: indices of @2^(k-1)@ or more are not valid
-- compositions.  The result runs from index @2^(k-1)-1@ (all factors
-- multiplied, i.e. @[n]@) down to index 0 (the plain factor list).
--
-- For @n <= 1@ there are no prime factors, and the result is empty.
compositions :: Int -> V.Vector (Vector Int)
compositions n
  | n <= 1 = V.empty
  | otherwise =
      let fs = allFactors n
      in V.reverse $ V.map (makeComp fs) $ V.enumFromN 0 (2 ^ (length fs - 1))
-- | Build one composition of the factor list @fs@, selected by index @i@.
--
-- With @k = length fs@, bits @k-2@ down to @0@ of @i@ give one flag per
-- adjacent pair of factors, most significant first.  A set bit multiplies the
-- next factor into the current group; a clear bit starts a new group.  For
-- @fs = [2,2,3]@: @i = 0@ gives @[2,2,3]@, @1@ gives @[2,6]@, @2@ gives
-- @[4,3]@ and @3@ gives @[12]@.
--
-- Valid indices are @0 <= i < 2^(k-1)@; anything else calls 'error'.  An
-- empty factor list yields an empty composition.
makeComp :: Vector Int -> Int -> Vector Int
makeComp fs i
  | null fs = empty
  | i < 0 || i >= bit (k - 1) =
      error "Numeric.FFT.Utils.makeComp: composition index out of range"
  | otherwise = fromList $ foldOps (toList fs) (makeOps k i)
  where
    k = length fs
    -- Walk the factors left to right, merging into the current group when
    -- the matching flag is True and starting a new group otherwise.
    foldOps :: [Int] -> [Bool] -> [Int]
    foldOps [] _ = []
    foldOps (f0 : rest) ops = go f0 rest ops
      where
        go acc (f : fs') (op : ops')
          | op = go (acc * f) fs' ops'
          | otherwise = acc : go f fs' ops'
        go acc _ _ = [acc]
    -- Bits n-2 .. 0 of j, most significant first: one flag per adjacent pair.
    makeOps :: Int -> Int -> [Bool]
    makeOps n j = P.map (testBit j) [n - 2, n - 3 .. 0]
-- | Every distinct permutation of the multiset @idp@, in lexicographic order.
-- The list starts from the sorted arrangement and repeatedly applies
-- 'permStep' until it returns 'Nothing'.
multisetPerms :: Vector Int -> [Vector Int]
multisetPerms idp = walk start
  where
    start = fromList (L.sort (toList idp))
    walk v = v : maybe [] walk (permStep v)
-- | The next lexicographic permutation of @v@ (the classic
-- "next permutation" step), or 'Nothing' if @v@ is already in
-- non-increasing order.
--
-- The comparisons are strict, so equal entries are never swapped with each
-- other.  Repeated application starting from a sorted vector therefore visits
-- each distinct arrangement of a multiset once.
permStep :: Vector Int -> Maybe (Vector Int)
permStep v
  | null ks = Nothing
  | otherwise =
      let k = maximum ks
          -- Rightmost position holding a value larger than the pivot.  The
          -- list is non-empty because v ! k < v ! (k + 1).
          l = maximum $ filter (\i -> v ! k < v ! i) $ enumFromN 0 n
      in Just $ revEnd k (swap k l)
  where
    n = length v
    -- Pivot candidates: positions i with v ! i < v ! (i + 1).
    ks = filter (\i -> v ! i < v ! (i + 1)) $ enumFromN 0 (n - 1)
    swap a b = generate n $ \i ->
      if i == a then v ! b
      else if i == b then v ! a
      else v ! i
    -- Reverse the suffix after position f: index i > f maps to n - 1 - (i - f - 1).
    revEnd f vv = generate n $ \i ->
      if i <= f then vv ! i else vv ! (n - i + f)
-- | In-place backpermute in 'ST': for each @i@ in @[0, n)@, writes
-- @vout[i] = vin[perm ! i]@.
--
-- Reads from @perm@ are bounds-checked ('indexM').  Reads from @vin@ and
-- writes to @vout@ are unchecked, so the caller must guarantee that
-- @n <= MV.length vout@ and that every @perm ! i@ (for @i < n@) is a valid
-- index into @vin@.  If @vin@ and @vout@ overlap, later reads may observe
-- earlier writes.
backpermuteM :: Int -> VI -> MVCD s -> MVCD s -> ST s ()
backpermuteM n perm vin vout =
  CM.forM_ [0 .. n - 1] $ \i -> do
    idx <- indexM perm i
    x <- MV.unsafeRead vin idx
    MV.unsafeWrite vout i x