module Data.Tree.LogTrees.FFTTree (
FFTTree (..), newFFTTree
) where
import Data.Complex
import Data.Tree
import Data.List
import Data.Newtypes.PrettyDouble (PrettyDouble(..))
import Data.Tree.LogTree
type FFTTree = GenericLogTree (Complex PrettyDouble)
instance LogTree FFTTree (Complex PrettyDouble) where
evalNode (Node (Just (x, wss), _, _, _) _, ms) = [foldl (*) x (zipWith getItem ms wss)]
where getItem m ws = ws !! m
evalNode (Node ( _, _, _, decimationType) children, ms) =
foldl (zipWith (+)) [0.0 | n <- [1..nodeLen]]
$ zipWith (zipWith (*)) subTransforms phasors
where subTransforms =
[ subCombFunc
$ map evalNode
[(child, ms ++ [m]) | m <- [0..(radix 1)]]
| child <- children
]
subCombFunc =
if decimationType == DIF then concat . transpose
else concat
childLen = length $ last(levels $ head children)
radix = length children
nodeLen = childLen * radix
phasors = [ [ cis((2.0) * pi / degree * fromIntegral r * fromIntegral k)
| k <- [0..(nodeLen 1)]]
| r <- [0..(radix 1)]]
degree | decimationType == DIF = fromIntegral radix
| otherwise = fromIntegral nodeLen
getTwiddles (Node ( _, _, _, decimationType) children) = calcTwiddles decimationType childLen radix
where childLen = length $ last(levels $ head children)
radix = length children
calcTwiddles decimationType childLen radix =
if decimationType == DIF
then [ [ cis((2.0) * pi / fromIntegral nodeLen * fromIntegral m * fromIntegral n)
| n <- [0..(childLen 1)]]
| m <- [0..(radix 1)]]
else [ [ 1.0 :+ 0.0
| n <- [0..(childLen 1)]]
| m <- [0..(radix 1)]]
where nodeLen = childLen * radix
getTwiddleStrs (Node ( _, _, _, decimationType) children) =
if decimationType == DIF
then map (map ((\str -> " [" ++ str ++ "]") . show)) $ getTwiddles (Node (Nothing, [], 0, decimationType) children)
else [["" | i <- [1..(length (last (levels child)))]] | child <- children]
getCompNodes (Node ( Just x, _, _, _) _) = []
getCompNodes (Node (Nothing, _, _, decimationType) children) =
[ [ (Sum, [ cis (2.0 * pi * k * r / degree)
| r <- map fromIntegral [0..(radix 1)]
]
)
| k <- map fromIntegral [childLen * r + m | r <- [0..(radix 1)]]
]
| m <- map fromIntegral [0..(childLen 1)]
] where childLen = fromIntegral $ length $ last(levels $ head children)
radix = length children
nodeLen = childLen * radix
degree | decimationType == DIF = fromIntegral radix
| otherwise = fromIntegral nodeLen
newFFTTree :: TreeBuilder FFTTree
newFFTTree = TreeBuilder buildMixedRadixTree