module Data.HInduce.Classifier.DecisionTree
(
DTree(..)
, DTreeAlgebra(..)
, buildDTree
, foldD
, toDot
, Decider(..)
, DecideOrd(..)
, DecideSet(..)
, AutoDecide(..)
, genOrds
, genOrdsAvg
, genPair
, genMany
, avgF, avgI
) where
import Data.Convertible
import Data.HInduce.Classifier
import Data.List
import Data.List.HIUtils
import Data.Maybe
import Data.Ord
import Data.Ratio
import Control.Arrow
import Text.Layout
import Text.Layout.DisplayText
import Text.Layout.DisplayLatex
-- | Things that can route an attribute value to a branch.
--
-- The functional dependency states that the decider type alone fixes both
-- the attribute type it inspects and the branch type it produces.
class Decider d a b | d -> a b where
  -- | Choose the branch that the given attribute value belongs to.
  decide :: d -> a -> b
-- | A threshold ("pivot") decider over an ordered attribute.
--
-- The datatype context (@Ord t =>@) was removed: it relies on the deprecated
-- @DatatypeContexts@ extension and adds no safety, because every function
-- that actually compares values already carries its own 'Ord' constraint.
-- The constructor is therefore strictly more general than before.
data DecideOrd t = DecideOrd t
  deriving (Show, Read)
-- | A membership decider: holds the list of values that count as "in".
--
-- The datatype context (@Eq t =>@) was removed: it relies on the deprecated
-- @DatatypeContexts@ extension and adds no safety, because every function
-- that actually tests membership already carries its own 'Eq' constraint.
-- The constructor is therefore strictly more general than before.
data DecideSet t = DecideSet [t]
  deriving (Show, Read)
-- | A decider paired with a list position: @d :!! i@ means "apply @d@ to the
-- element at index @i@" (see the 'Decider' instance for 'Ixd').
data Ixd d
  = d :!! Int
  deriving (Show, Read)
-- | Threshold test: 'True' when the value is less than or equal to the pivot.
instance (Ord attr) => Decider (DecideOrd attr) attr Bool where
  decide (DecideOrd threshold) value = value <= threshold
-- | Membership test: 'True' when the value occurs in the decider's list.
instance (Eq attr) => Decider (DecideSet attr) attr Bool where
  decide (DecideSet members) value = value `elem` members
-- | Lift a decider on single attributes to lists of attributes by applying
-- it to the element at the stored index.  Uses '!!', so an index outside
-- the list is a runtime error.
instance (Decider decider attr branch) => Decider (Ixd decider) [attr] branch where
  decide (inner :!! position) attrs = decide inner (attrs !! position)
-- | Decide on a pair of attributes: a 'Left' decider looks only at the first
-- component, a 'Right' decider only at the second.
instance
  (Decider deca attra branch, Decider decb attrb branch) =>
  Decider (Either deca decb) (attra, attrb) branch
  where
  decide (Left onFirst) pair = decide onFirst (fst pair)
  decide (Right onSecond) pair = decide onSecond (snd pair)
-- | A function that proposes candidate deciders from a list of attribute
-- values (the "window").
type DeciderGenerator a d = [a] -> [d]
-- | Attribute types that come with a default way of generating deciders.
--
-- The functional dependency states that the attribute type fixes the
-- decider type.
class AutoDecide a d | a -> d where
  -- | Propose candidate deciders for the given attribute values.
  autoDeciders :: [a] -> [d]
-- | One threshold decider per window element, skipping every element equal
-- to the window's maximum.  Duplicates in the window are kept.
genOrds :: (Ord attr) => [attr] -> [DecideOrd attr]
genOrds window = [DecideOrd value | value <- window, value /= largest]
  where
    -- Only forced when the window is non-empty, so an empty window yields [].
    largest = maximum window
-- | Threshold deciders placed "between" neighbouring values: the supplied
-- averaging function is applied to each pair of adjacent elements of
-- @uniqSort window@ (presumably the sorted, de-duplicated window; 'uniqSort'
-- is defined outside this module).
genOrdsAvg :: (Ord attr) => (attr -> attr -> attr) -> [attr] -> [DecideOrd attr]
genOrdsAvg favg window = map DecideOrd midpoints
  where
    distinct = uniqSort window
    midpoints = zipWith favg distinct (drop 1 distinct)
-- | Membership deciders: one 'DecideSet' for every subsequence of
-- @uniqSort window@, including the empty one and the full one.  The number
-- of deciders therefore doubles with every additional value 'uniqSort'
-- returns.
genEq :: (Ord attr) => [attr] -> [DecideSet attr]
genEq = map DecideSet . subsequences . uniqSort
-- | Combine two generators into one for pairs: deciders proposed for the
-- first components are tagged 'Left', those for the second components are
-- tagged 'Right', and the 'Left' ones come first in the result.
genPair ::
  DeciderGenerator attra decidera ->
  DeciderGenerator attrb deciderb ->
  DeciderGenerator (attra, attrb) (Either decidera deciderb)
genPair p q v = fromFirsts ++ fromSeconds
  where
    fromFirsts = map Left (p (map fst v))
    fromSeconds = map Right (q (map snd v))
-- | Lift a generator for single attributes to fixed-length lists of
-- attributes.  For every list position @i@ the underlying generator is run
-- on the @i@-th elements of all observations, and each resulting decider is
-- tagged with that position via ':!!'.
--
-- * An empty window yields no deciders (previously this hit @head []@).
-- * Raises an 'error' if the observations do not all have the same length.
genMany ::
  DeciderGenerator attr decider ->
  DeciderGenerator [attr] (Ixd decider)
genMany d window = case map length window of
  [] -> []
  (width : others)
    | any (/= width) others ->
        error "Refusing to generate deciders for variable length list attribute"
    | otherwise ->
        -- Valid positions are 0 .. width - 1; the subtraction had been lost.
        [ d' :!! i
        | i <- [0 .. width - 1]
        , d' <- d (map (!! i) window)
        ]
-- | Arithmetic mean of two fractional values.
avgF :: (Fractional a) => a -> a -> a
avgF a b = total / 2
  where
    total = a + b
-- | Mean of two integral values, rounded towards negative infinity
-- (it uses 'div').
avgI :: (Integral a) => a -> a -> a
avgI a b = div (a + b) 2
-- Default deciders for primitive attribute types.
--
-- Numeric types get threshold deciders placed at the average of adjacent
-- values (see 'genOrdsAvg'); characters and strings get membership deciders
-- (see 'genEq').

instance AutoDecide Double (DecideOrd Double) where
  autoDeciders = genOrdsAvg avgF

instance AutoDecide Float (DecideOrd Float) where
  autoDeciders = genOrdsAvg avgF

instance AutoDecide Int (DecideOrd Int) where
  autoDeciders = genOrdsAvg avgI

instance AutoDecide Integer (DecideOrd Integer) where
  autoDeciders = genOrdsAvg avgI

instance (Integral a) => AutoDecide (Ratio a) (DecideOrd (Ratio a)) where
  autoDeciders = genOrdsAvg avgF

instance AutoDecide Char (DecideSet Char) where
  autoDeciders = genEq

instance AutoDecide [Char] (DecideSet [Char]) where
  autoDeciders = genEq
-- | Pairs get the default deciders of both components, combined with
-- 'genPair' ('Left' for the first component, 'Right' for the second).
instance
  (AutoDecide a xa, AutoDecide b xb) =>
  AutoDecide (a, b) (Either xa xb)
  where
  autoDeciders = genPair autoDeciders autoDeciders
-- | Tag every observation with the branch the decider picks for its
-- attribute, then collect the tagged observations with 'aggregateAL'
-- (defined elsewhere; presumably it groups values by key).
doSplit ::
  (Decider decider attr branch, Ord branch) =>
  (x -> attr) ->
  decider ->
  [x] ->
  [(branch, [x])]
doSplit toattr dec observations = aggregateAL (map tagged observations)
  where
    tagged obs = (decide dec (toattr obs), obs)
-- | Tag every observation with its label, then collect the tagged
-- observations with 'aggregateAL' (defined elsewhere; presumably it groups
-- values by key).
doLabel :: (Ord label) => (x -> label) -> [x] -> [(label, [x])]
doLabel tolabel observations = aggregateAL (map tagged observations)
  where
    tagged obs = (tolabel obs, obs)
-- | Score a split: for every branch, compute the 'gini' value of the label
-- counts inside that branch and weight it by the number of observations in
-- the branch; the score is the sum of these weighted values.
measureImpurity :: (Ord label) => (attr -> label) -> [(branch, [attr])] -> Double
measureImpurity tolabel branches =
  sum [impurityOf members * fromIntegral (length members) | members <- map snd branches]
  where
    -- 'gini' applied to the per-label observation counts of one branch.
    impurityOf :: [attr] -> Double
    impurityOf members = gini (map (length . snd) (doLabel tolabel members))
-- | Gini impurity of a list of class counts: the sum of @p * (1 - p)@ over
-- the values returned by 'relFreq' (defined elsewhere; presumably the
-- relative frequency of each count).
--
-- The subtraction in @1 - p@ had been lost from the source, leaving the
-- ill-typed expression @(1 x)@.
gini :: (Integral i, Fractional f) => [i] -> f
gini = sum . map (\p -> p * (1 - p)) . relFreq
-- | Generate candidate deciders from the window's attributes and pair each
-- one with the impurity score ('measureImpurity') of the split it induces
-- on the window.
rateSplits ::
  (Decider decider attr branch, Ord branch, Ord label) =>
  DeciderGenerator attr decider ->
  (x -> attr) ->
  (x -> label) ->
  [x] ->
  [(decider, Double)]
rateSplits decGen toattr tolabel window =
  [ (dec, measureImpurity tolabel (doSplit toattr dec window))
  | dec <- decGen (map toattr window)
  ]
-- | A decision tree.
--
-- * 'Node' holds a decider and one subtree per branch the decider can pick.
-- * 'Leaf' holds the predicted label.
data DTree decider branch label
  = Node
      { dDecider :: decider
      , dChildren :: [(branch, DTree decider branch label)]
      }
  | Leaf
      { dLabel :: label
      }
  deriving (Show, Eq)
-- | The two functions needed to fold a 'DTree' into a value of type @a@
-- (see 'foldD').
data DTreeAlgebra decider branch label a = DTreeAlgebra
  { fleaf :: label -> a
  -- ^ Result for a leaf, given its label.
  , fnode :: decider -> [(branch, a)] -> a
  -- ^ Result for a node, given its decider and the already folded children.
  }
-- | Fold a decision tree bottom-up: leaves are turned into results with
-- 'fleaf', nodes combine their decider with the folded children via 'fnode'.
foldD :: DTreeAlgebra dec branch label a -> DTree dec branch label -> a
foldD algebra = go
  where
    go (Leaf label) = fleaf algebra label
    go (Node dec children) = fnode algebra dec (map (second go) children)
-- | Algebra that predicts the label of one observation: a leaf yields its
-- label, a node yields the result of the child whose branch the decider
-- picks for the observation.
--
-- Raises @error "Incomplete tree"@ if a node has no child for that branch.
predictAlgebra ::
  (Decider dec attr branch, Eq branch) =>
  attr ->
  DTreeAlgebra dec branch label label
predictAlgebra newobservation = DTreeAlgebra {fleaf = id, fnode = follow}
  where
    follow dec children =
      case lookup (decide dec newobservation) children of
        Just prediction -> prediction
        Nothing -> error "Incomplete tree"
-- | Predict the label of an observation by folding the tree with
-- 'predictAlgebra'.
predict ::
  (Decider dec attr branch, Eq branch) =>
  attr ->
  DTree dec branch a ->
  a
predict observation tree = foldD (predictAlgebra observation) tree
-- | A decision tree is a classifier: classifying an observation is
-- 'predict' with its arguments swapped.
instance
  (Decider decider attr branch, Eq branch) =>
  Classifier (DTree decider branch label) attr label
  where
  classify tree observation = predict observation tree
-- | 'buildDTree' using the attribute type's default decider generator
-- ('autoDeciders').
buildDTree' ::
  ( Ord label
  , Ord branch
  , AutoDecide attr dec
  , Decider dec attr branch
  ) =>
  (x -> attr) ->
  (x -> label) ->
  [x] ->
  DTree dec branch label
buildDTree' toAttr toLabel window = buildDTree autoDeciders toAttr toLabel window
-- | Build a decision tree from a window of observations.
--
-- Arguments: a generator of candidate deciders, a projection from an
-- observation to its attribute, a projection to its label, and the window.
--
-- * No candidate split and an empty window: raises @error "Empty window"@.
-- * No candidate split otherwise: a leaf with the 'majority' label.
-- * All observations share one label: a leaf with that label.
-- * Otherwise the split with the lowest 'measureImpurity' score becomes a
--   node and the tree is built recursively for each branch's sub-window.
--
-- If the best-rated split sends every observation to the same branch, a
-- leaf with the 'majority' label is produced instead of a node.  Recursing
-- there would rebuild the tree for the very same window and never finish
-- (this can happen, e.g., when a generated decider accepts or rejects every
-- value and ties with the other candidates).
buildDTree ::
  (Ord label, Ord branch, Decider decider attr branch) =>
  DeciderGenerator attr decider ->
  (x -> attr) ->
  (x -> label) ->
  [x] ->
  DTree decider branch label
buildDTree decGen toAttr toLabel window =
  case rateSplits decGen toAttr toLabel window of
    []
      | null window -> error "Empty window"
      | otherwise -> majorityLeaf
    splits -> case uniqSort labels of
      [x] -> Leaf x
      _ ->
        let (best, _) = minimumBy (comparing snd) splits
            subwins = doSplit toAttr best window
         in case subwins of
              -- The split separated nothing: stop here to guarantee progress.
              [_] -> majorityLeaf
              _ -> Node best (map (second grow) subwins)
  where
    labels = map toLabel window
    majorityLeaf = Leaf (majority labels)
    grow = buildDTree decGen toAttr toLabel
-- | Render a decision tree as a Graphviz @digraph@.
--
-- Every tree node gets an identifier derived from its path: the root is
-- @dtree@ and the n-th child of @p@ is @p_c<n>@, numbered via 'nat0'
-- (defined elsewhere; presumably 0, 1, 2, ...).  Leaves are labelled with
-- their label, inner nodes with their decider, edges with the branch key.
toDot ::
  (Show decider, Show branch, Show label) =>
  DTree decider branch label ->
  String
toDot t = "digraph G {\n" ++ foldD algebra t "dtree" ++ "}\n"
  where
    -- Each subtree folds to a function awaiting the identifier it should use.
    algebra = DTreeAlgebra {fleaf = leafDot, fnode = nodeDot}

    quoted value = show (show value)

    leafDot label pfx = pfx ++ " [label=" ++ quoted label ++ "];\n"

    nodeDot dec cs pfx =
      pfx ++ " [label=" ++ quoted dec ++ " shape=plaintext];\n"
        ++ concatMap (edgeDot pfx) (zip nat0 cs)

    edgeDot pfx (n, (key, renderChild)) =
      pfx ++ " -> " ++ childId ++ "[label=" ++ quoted key ++ "];\n"
        ++ renderChild childId
      where
        childId = pfx ++ "_c" ++ show n
-- | Plain-text rendering of a tree as an indented outline: a node line is
-- followed, for each child, by the branch key (indented by 2 more spaces)
-- and the child subtree (indented by 4 more spaces).
instance
  (Show decider, Show branch, Show label) =>
  Convertible (DTree decider branch label) DisplayText
  where
  safeConvert tree = Right (DisplayText (render 0 tree))
    where
      pad width = replicate width ' '

      render depth (Leaf label) =
        pad depth ++ "Leaf " ++ show label ++ "\n"
      render depth (Node dec children) =
        pad depth ++ "Node " ++ show dec ++ "\n"
          ++ concatMap (renderBranch depth) children

      renderBranch depth (key, subtree) =
        pad (depth + 2) ++ show key ++ "\n" ++ render (depth + 4) subtree
-- | LaTeX rendering of a tree as a TikZ picture using TikZ's
-- @node{..} child {..}@ tree syntax.  The text of every non-root node
-- starts with the key of the branch leading to it, followed by a line
-- break (@\\\\@).
--
-- Leaves are emitted as @node{..}@ just like inner nodes.  Previously they
-- were emitted as @\\node {..}@, which is not allowed inside a @child@ path
-- (and became @\\\\node@ for a tree consisting of a single leaf), and the
-- branch key leading to the leaf was dropped.
instance
  (Show decider, Show branch, Show label) =>
  Convertible (DTree decider branch label) DisplayLatex
  where
  safeConvert t =
    Right $
      DisplayLatex $
        -- The trailing backslash turns the root's "node{" into "\node{".
        "\\begin{tikzpicture}\n\\" ++ printTree 0 "" t ++ ";\\end{tikzpicture}\n"
    where
      spcs = flip replicate ' '

      printTree indent key Node {dChildren = children, dDecider = dec} =
        spcs indent ++ "node{" ++ key ++ show dec ++ "}\n"
          ++ concatMap printChild children
          ++ "\n"
        where
          printChild (k, v) =
            "child {\n"
              ++ printTree (indent + 4) (show k ++ "\\\\\n") v
              ++ "}\n"
      printTree indent key Leaf {dLabel = label} =
        spcs indent ++ "node{" ++ key ++ show label ++ "}\n"