module Data.Tree.LogTree (
newTreeData, dotLogTree
, buildTree, newFFTTree
, getLevels, getFlatten, getEval
, modes, values
) where
import Data.Complex
import Data.Tree
import Data.List
import Text.Printf (printf, PrintfArg)
import Control.Monad.State.Lazy
import Data.Newtypes.PrettyDouble (PrettyDouble(..))
-- | A rose tree in which every node is labelled with a 4-tuple:
--
--   1. an optional value (nodes matched as @Just x@ are treated as leaves
--      by the functions in this module),
--   2. a list of offsets,
--   3. a skip factor,
--   4. a flag that this module binds as @dif@ (presumably
--      \"decimation in frequency\" - confirm against the project docs).
type GenericLogTree a = Tree (Maybe a, [Int], Int, Bool)
-- | Trees that can be evaluated to a list of values.
--
--   The superclass equality pins @t@ to 'GenericLogTree' @a@, and the
--   functional dependency states that the tree type determines the element
--   type.
class (t ~ GenericLogTree a) => LogTree t a | t -> a where
    -- | Evaluate the (sub)tree rooted at the given node.
    evalNode :: t -> [a]
-- | A 'GenericLogTree' whose values are complex numbers over 'PrettyDouble'.
type FFTTree = GenericLogTree (Complex PrettyDouble)
-- | Evaluation of an 'FFTTree'.
--
--   A node that carries a value evaluates to that single value.  Any other
--   node evaluates each child into a \"sub-transform\", multiplies every
--   sub-transform element-wise by its row of phasors, and sums the products
--   element-wise.
instance LogTree FFTTree (Complex PrettyDouble) where
    evalNode (Node (Just x, _, _, _) _) = [x]
    evalNode (Node (_, _, _, dif) children) =
        foldl' (zipWith (+)) (replicate nodeLen 0.0)
            $ zipWith (zipWith (*)) subTransforms phasors
      where
        -- When the flag is set, every child is first scaled leaf-by-leaf by
        -- each twiddle row (via 'coProd'), the scaled copies are evaluated,
        -- and their results are interleaved.  Otherwise each child's result
        -- is simply repeated @radix@ times.
        subTransforms
            | dif =
                [ concat $ transpose $ map evalNode
                    [ snd (coProd twiddle child) | twiddle <- twiddles ]
                | child <- children
                ]
            | otherwise = map (concat . replicate radix) subs
        subs     = map evalNode children
        -- Number of nodes on the deepest level of the first child.
        childLen = length $ last (levels $ head children)
        nodeLen  = childLen * radix
        radix    = length children
        -- One row per child (r), one column per output index (k).  The
        -- negative imaginary unit gives the forward-transform sign.
        phasors  =
            [ [ exp ((0.0 :+ (-1.0)) * 2.0 * pi / degree
                     * fromIntegral r * fromIntegral k)
              | k <- [0 .. (nodeLen - 1)]
              ]
            | r <- [0 .. (radix - 1)]
            ]
        degree
            | dif       = fromIntegral radix
            | otherwise = fromIntegral nodeLen
        -- One row per child-copy (m), one column per child element (n).
        twiddles =
            [ [ exp ((0.0 :+ (-1.0)) * 2.0 * pi / fromIntegral nodeLen
                     * fromIntegral m * fromIntegral n)
              | n <- [0 .. (childLen - 1)]
              ]
            | m <- [0 .. (radix - 1)]
            ]
-- | Multiply the valued nodes of a tree, visited left to right, by
--   successive coefficients taken from the front of a list.
--
--   Returns the coefficients that were not consumed together with the scaled
--   tree.  A valued node reached with no coefficients left keeps its value.
--   Any children of a valued node are dropped from the result.
coProd :: (Num a, t ~ GenericLogTree a) => [a] -> t -> ([a], t)
coProd coeffs (Node (Just x, offsets, skipFactor, dif) _) =
    case coeffs of
        []         -> ([], leaf x)
        (c : rest) -> (rest, leaf (c * x))
  where
    leaf v = Node (Just v, offsets, skipFactor, dif) []
coProd coeffs (Node (_, offsets, skipFactor, dif) children) =
    (leftover, Node (Nothing, offsets, skipFactor, dif) scaledChildren)
  where
    -- Thread the remaining coefficients through the children in order.
    (leftover, scaledChildren) = mapAccumL coProd coeffs children
-- | One left-fold step over a list of sibling trees: scale the next tree
--   with the coefficients still available, append it to the trees already
--   processed, and hand the unused coefficients on.
coProdStep :: (Num a, t ~ GenericLogTree a) => ([a], [t]) -> t -> ([a], [t])
coProdStep (coeffs, done) tree =
    let (leftover, scaled) = coProd coeffs tree
    in  (leftover, done ++ [scaled])
-- | Input description from which a tree is built.
data TreeData a = TreeData {
    -- | One @(radix, dif)@ pair per tree level; the builder reads the head
    --   pair for the current level and passes the tail to the children.
    modes :: [(Int, Bool)]
    -- | The values to be distributed over the leaves.
  , values :: [a]
  } deriving(Show)
-- | Build a 'TreeData' from its per-level modes and its list of values.
newTreeData :: [(Int, Bool)] -> [a] -> TreeData a
newTreeData levelModes leafValues = TreeData levelModes leafValues
-- | Wraps a function that turns a 'TreeData' into a tree of type @t@, or
--   into an error message when construction fails.
newtype TreeBuilder t = TreeBuilder {
    -- | Run the builder on the given data.
    buildTree :: LogTree t a => TreeData a -> Either String t
  }
-- | The builder for 'FFTTree's, backed by 'buildMixedRadixTree'.
newFFTTree :: TreeBuilder FFTTree
newFFTTree = TreeBuilder { buildTree = buildMixedRadixTree }
-- | Unpack a 'TreeData' and hand its modes and values to 'mixedRadixTree'.
buildMixedRadixTree :: TreeData a -> Either String (GenericLogTree a)
buildMixedRadixTree (TreeData levelModes leafValues) =
    mixedRadixTree levelModes leafValues
-- | Entry point of the mixed-radix tree construction.
--
--   An empty input is an error, a single value becomes a lone leaf with no
--   offsets, and anything longer is delegated to 'mixedRadixRecurse' starting
--   at offset 0 with skip factor 1.
mixedRadixTree :: [(Int, Bool)] -> [a] -> Either String (GenericLogTree a)
mixedRadixTree levelModes samples =
    case samples of
        []  -> Left "mixedRadixTree(): called with empty list."
        [x] -> Right (Node (Just x, [], 0, False) [])
        _   -> mixedRadixRecurse 0 1 levelModes samples
-- | Recursive worker of 'mixedRadixTree'.
--
--   Arguments: this node's offset, this node's skip factor, the remaining
--   @(radix, dif)@ modes (head = current level), and the values belonging to
--   this subtree.
--
--   Returns @Left@ when the value list is empty or when the product of the
--   remaining radices differs from the number of values; otherwise the
--   subtree, whose interior nodes record their children's offsets and skip
--   factor together with the current @dif@ flag.
mixedRadixRecurse :: Int -> Int -> [(Int, Bool)] -> [a] -> Either String (GenericLogTree a)
mixedRadixRecurse _ _ _ [] = Left "mixedRadixRecurse(): called with empty list."
mixedRadixRecurse myOffset _ _ [x] = return $ Node (Just x, [myOffset], 0, False) []
mixedRadixRecurse myOffset mySkipFactor levelModes xs
    | product (map fst levelModes) == length xs = do
        children <- sequence
            [ mixedRadixRecurse childOffset childSkipFactor
                                (tail levelModes) subList
            | (childOffset, subList) <- zip childOffsets subLists
            ]
        return $ Node (Nothing, childOffsets, childSkipFactor, dif) children
    | otherwise =
        Left "mixedRadixTree(): Product of radices must equal length of input."
  where
    -- The values handed to each child: contiguous runs when @dif@ is set,
    -- every @radix@-th element otherwise.
    subLists =
        [ [ xs !! (offset + i * skipFactor) | i <- [0 .. (childLen - 1)] ]
        | offset <- offsets
        ]
    childSkipFactor
        | dif       = mySkipFactor
        | otherwise = mySkipFactor * radix
    childOffsets
        | dif       = [ myOffset + (i * mySkipFactor * childLen)
                      | i <- [0 .. (radix - 1)] ]
        | otherwise = [ myOffset + i * mySkipFactor
                      | i <- [0 .. (radix - 1)] ]
    skipFactor
        | dif       = 1
        | otherwise = radix
    offsets
        | dif       = [ i * childLen | i <- [0 .. (radix - 1)] ]
        | otherwise = [0 .. (radix - 1)]
    childLen = length xs `div` radix
    -- 'head'/'tail' are only forced after the guard has succeeded.  With at
    -- least two values an empty mode list has product 1 and is rejected
    -- above, so the mode list is non-empty here.
    radix    = fst $ head levelModes
    dif      = snd $ head levelModes
-- | Operation tag attached to each input of a 'CompNode'.  Only 'Sum' is
--   produced by the code visible in this module.
data CompOp = Sum
            | Prod
-- | A computational node of the drawing: its inputs as
--   @(node ID, field ID)@ pairs, a list of coefficients, and one 'CompOp'
--   per input.
type CompNode a = ([(String, String)], [a], [CompOp])
-- | Render a tree-construction result as a GraphViz @digraph@.
--
--   A failed construction yields a graph with a single node labelled with
--   the error message; a successful one yields the drawing produced by
--   'dotLogTreeRecurse', started at node ID @\"0\"@ with no computational
--   nodes registered.
dotLogTree :: (Show a, LogTree t a) => Either String t -> String
dotLogTree result = header ++ body ++ "}\n"
  where
    body = case result of
        Left msg   -> "\"node0\" [label = \"" ++ msg ++ "\"]\n"
        Right tree -> evalState (dotLogTreeRecurse "0" tree) []
-- | Opening of the GraphViz output: the @digraph g {@ line followed by the
--   graph, node and edge attribute settings.  The closing brace is appended
--   by 'dotLogTree'.
header = "digraph g { \n \
\ graph [ \n \
\ rankdir = \"RL\" \n \
\ splines = \"false\" \n \
\ ]; \n \
\ node [ \n \
\ fontsize = \"16\" \n \
\ shape = \"circle\" \n \
\ height = \"0.3\" \n \
\ ]; \n \
\ ranksep = \"1.5\";\n \
\ nodesep = \"0\";\n \
\ edge [ \n \
\ dir = \"back\" \n \
\ ];\n"
-- | Emit the GraphViz statements for the subtree rooted at a node.
--
--   The first argument is the node's ID (used in the @\"node<ID>\"@ names),
--   the second the subtree.  The state is the list of computational nodes
--   registered so far, so that identical ones are drawn only once.
--
--   A valued node becomes a one-field record showing its first offset (when
--   it has one) and its value.  Any other node becomes a record with one
--   field per element of its evaluated result, followed by the output for
--   its children and by one connection per result element to the
--   computational node that feeds it.
dotLogTreeRecurse :: (Show a, LogTree t a) => String -> t -> State [CompNode a] String
dotLogTreeRecurse nodeID (Node (Just x, offsets, _, _) _) =
    return $ "\"node" ++ nodeID ++ "\" [label = \"<f0> "
        ++ offsetLabel ++ show x
        ++ "\" shape = \"record\"];\n"
  where
    -- A single-element tree is built with an empty offset list; omit the
    -- index prefix there instead of failing on the head of an empty list.
    offsetLabel = case offsets of
        (o : _) -> "[" ++ show o ++ "] "
        []      -> ""
dotLogTreeRecurse nodeID (Node (_, childOffsets, skip, dif) children) = do
    let selfStr =
            "\"node" ++ nodeID ++ "\" [label = \"<f0> "
            ++ show (head res)
            ++ concat [ " | <f" ++ show k ++ "> " ++ show val
                      | (val, k) <- zip (tail res) [1 ..]
                      ]
            ++ "\" shape = \"record\"];\n"
    -- Children are rendered in order, sharing the computational-node state.
    childrenStr <- liftM concat $ zipWithM dotLogTreeRecurse childIDs children
    -- Output element k is fed by element (k mod num_child_elems) of every
    -- child; look up (or register) that computational node and connect it.
    conStrs <- forM [0 .. (num_elems - 1)] $ \k -> do
        (compNodeID, compNodeDrawStr) <-
            getCompNodeID (k `mod` num_child_elems) childIDs
        return $ compNodeDrawStr ++ drawConnection nodeID k compNodeID
    return (selfStr ++ childrenStr ++ concat conStrs)
  where
    num_elems       = length children * num_child_elems
    num_child_elems = length $ last (levels $ head children)
    childIDs        = [ nodeID ++ show i | i <- [0 .. (length children - 1)] ]
    res             = evalNode $ Node (Nothing, childOffsets, skip, dif) children
-- | GraphViz edge statement from field @k@ of a record node to a
--   computational node, terminated by @;@ and a newline.
drawConnection :: Show k => String -> k -> String -> String
drawConnection fromID field toID =
    concat
        [ "\"node", fromID, "\":f", show field
        , " -> \"node", toID, "\""
        , ";\n"
        ]
-- | Stateful wrapper around 'fetchCompNodeID': looks up (or registers) the
--   computational node fed by field @k@ of the given children, replaces the
--   state with the resulting node list, and returns the node's ID together
--   with any GraphViz text needed to draw it.
getCompNodeID :: Int -> [String] -> State [CompNode a] (String, String)
getCompNodeID k childIDs = state $ \known ->
    let (known', compNodeID, drawStr) = fetchCompNodeID k childIDs known
    in  ((compNodeID, drawStr), known')
-- | Find or create the computational node whose inputs are field @k@ of
--   each of the given child nodes.
--
--   Returns the (possibly extended) list of computational nodes, the ID of
--   the matching node, and the GraphViz text that draws it.  The text is
--   empty when the node already existed; a newly registered node is drawn as
--   a small circle with one edge to each of its inputs.
fetchCompNodeID :: Int -> [String] -> [CompNode a]
                -> ([CompNode a], String, String)
fetchCompNodeID k childIDs compNodes =
    maybe registerNew reuse (findCompNode 0 inputList compNodes)
  where
    fieldID   = show k
    inputList = [ (childID, fieldID) | childID <- childIDs ]

    reuse foundNodeID = (compNodes, '1' : show foundNodeID, "")

    registerNew =
        ( compNodes ++ [(inputList, [], replicate (length inputList) Sum)]
        , newNodeID
        , nodeDecl ++ unlines (map inputEdge inputList)
        )

    -- IDs of computational nodes are '1' followed by their list position.
    newNodeID = '1' : show (length compNodes)

    nodeDecl = concat
        [ "\"node", newNodeID, "\""
        , "[label = \".\""
        , ", shape = \"circle\""
        , ", height = \"0.25\""
        , ", fixedsize = \"true\""
        , "];\n"
        ]

    inputEdge (srcNode, srcField) = concat
        [ "\"node", newNodeID, "\""
        , " -> "
        , "\"node", srcNode, "\""
        , ":f", srcField
        , ":e"
        , ";\n"
        ]
-- | Position of the first computational node whose inputs contain every
--   element of the requested input list, counted from the given start
--   index, or 'Nothing' when no node matches.
findCompNode :: Int -> [(String, String)] -> [CompNode a] -> Maybe Int
findCompNode index inputList compNodes =
    fmap (+ index) (findIndex coversInputs compNodes)
  where
    coversInputs (inputs, _, _) = all (`elem` inputs) inputList
-- | The optional value stored in the root node of a tree.
getValue :: LogTree t a => t -> Maybe a
getValue tree =
    case rootLabel tree of
        (val, _, _, _) -> val
-- | Evaluate a successfully built tree; a failed build yields the empty
--   list (the error message is discarded).
getEval :: LogTree t a => Either e t -> [a]
getEval result = either (const []) evalNode result
-- | Node labels of a successfully built tree, grouped by depth; a failed
--   build yields the empty list (the error message is discarded).
getLevels :: Either e (Tree a) -> [[a]]
getLevels result = either (const []) levels result
-- | Node labels of a successfully built tree, grouped by depth; a failed
--   build yields the empty list (the error message is discarded).
--
--   NOTE(review): this is currently identical to 'getLevels' - it calls
--   'levels', not 'flatten'.  Given the name this looks like a copy-paste
--   slip, but switching to 'flatten' would change the exported result type
--   from a nested to a flat list, so confirm with callers before changing.
getFlatten :: Either e (Tree a) -> [[a]]
getFlatten result = either (const []) levels result