{-# LANGUAGE TemplateHaskell #-}
module Prediction where
import Data.Maybe
import Data.Either
import Data.Graph
import Data.List hiding (transpose)
import Data.Matrix hiding ((!), trace)
import Data.Map.Strict (Map, (!))
import qualified Data.Map.Strict as Map
import Test.QuickCheck
import TypeInfo
import Countable
import Debug.Trace
-- | Size parameter: 'confirm' passes it to QuickCheck's 'resize', and
-- 'predict' uses it as the number of levels to sum over.
type Size = Int
-- | Integer frequency per constructor name (see 'initMap' for how the
-- starting values are chosen).
type FreqMap = Map Name Int
-- | Frequencies for a fixed set of built-in constructors: @True@, @False@,
-- @Left@ and @Right@ each get weight 1.
builtInFreqs :: FreqMap
builtInFreqs = Map.fromList [(con, 1) | con <- builtInCons]
  where
    -- Template Haskell name quotes for the constructors covered here.
    builtInCons = ['True, 'False, 'Left, 'Right]
-- | Build the starting frequency map for a type environment, with one entry
-- per constructor name returned by 'consList'. A constructor present in
-- 'builtInFreqs' gets its built-in weight multiplied by 100; every other
-- constructor gets 100.
initMap :: TypeEnv -> FreqMap
initMap env = Map.fromList [(cn, initialFreq cn) | cn <- consList env]
  where
    -- A constructor without a built-in weight behaves as if its weight were 1,
    -- which yields the plain default of 100.
    initialFreq cn = 100 * Map.findWithDefault 1 cn builtInFreqs
-- | A 'Double' per constructor name. Produced by 'normalize' and
-- 'normalizeTerminals' (frequency ratios) and returned by 'predict'.
type ProbMap = Map Name Double
-- | Render a map as a bullet list: one @(key,value)@ pair per line, in
-- ascending key order, each prefixed with @" * "@. Lines are separated by
-- newlines with no trailing newline; an empty map renders as the empty string.
showMap :: (Show a, Show b) => Map a b -> String
showMap m = intercalate "\n" [" * " ++ show entry | entry <- Map.toList m]
-- | Keep only the entries whose key satisfies the predicate. Values are
-- carried over untouched and never inspected.
filterKeys :: (Name -> Bool) -> Map Name b -> Map Name b
filterKeys keep = Map.filterWithKey (\key _ -> keep key)
-- | Turn frequencies into ratios. Each constructor's frequency is divided by
-- the summed frequency of every entry of the map whose key 'isSibling' relates
-- to that constructor in the given environment.
--
-- NOTE(review): the division is unguarded; if the sibling sum is 0 the result
-- is NaN or Infinity. Confirm callers never pass such frequency maps.
normalize :: TypeEnv -> FreqMap -> ProbMap
normalize env freqMap = Map.mapWithKey ratio freqMap
  where
    ratio cn freq = fromIntegral freq / fromIntegral (siblingSum cn)
    -- Sum of the frequencies of all map entries that are siblings of @cn@.
    siblingSum cn =
      sum [f | (other, f) <- Map.toList freqMap, isSibling env cn other]
-- | Like 'normalize', but restricted to terminal constructors: entries whose
-- key does not satisfy 'isTerminal' are dropped first, and each remaining
-- frequency is then divided by the summed frequency of its siblings among the
-- remaining (terminal) entries only.
--
-- This is exactly 'normalize' applied to the filtered map, so it delegates to
-- it instead of repeating the ratio computation.
normalizeTerminals :: TypeEnv -> FreqMap -> ProbMap
normalizeTerminals env = normalize env . filterKeys (isTerminal env)
-- | Build a square matrix with one row and one column per type in @env@
-- (presumably the mean matrix of a Galton–Watson branching process, going by
-- the name — confirm).
--
-- Entry @(m, n)@ sums, over every constructor of the @m@-th type, that
-- constructor's probability times the number of occurrences of the @n@-th
-- type's signature in it (as reported by 'occurrences').
--
-- Every constructor name of every type in @env@ must be a key of @probMap@;
-- the lookup is partial.
genGWMatrix :: TypeEnv -> ProbMap -> Matrix Double
genGWMatrix env probMap = matrix dim dim entry
  where
    dim = length env
    -- The generator receives 1-based indices, hence the @- 1@ when indexing.
    entry (row, col) =
      sum
        [ probMap ! cname con * fromIntegral (occurrences targetSig con)
        | con <- tcons source
        ]
      where
        source = env !! (row - 1)
        targetSig = tsig (env !! (col - 1))
-- | Compute a per-constructor 'ProbMap' prediction for the given environment,
-- size and frequencies (presumably the expected number of occurrences of each
-- constructor in a generated value of that size — confirm against 'confirm').
--
-- Outline:
--
-- * The first type of @env@ is the root. Types are split into /branching/
--   types (the root, or any type with a constructor for which 'rec' holds)
--   and /leaf/ types (everything else).
-- * For branching types, a row vector selecting the root is multiplied by
--   powers of the 'genGWMatrix' matrix. Levels @0 .. size - 1@ are weighted
--   with the normalised probabilities; one further level is weighted with the
--   terminal-only probabilities.
-- * Leaf-type constructors are then added one by one, in an order in which
--   every constructor comes after the leaf constructors that mention its type
--   as an argument, so their values are already in the accumulated map.
--
-- NOTE(review): for @size < 1@ the level sum is taken over an empty list and
-- the previous level uses a negative exponent; both raise an error once the
-- affected entries are demanded.
predict :: TypeEnv -> Size -> FreqMap -> ProbMap
predict env size freqs = foldl' addLeafCon branchingExp orderedLeafCons
  where
    allProbs = normalize env freqs
    termProbs = normalizeTerminals env freqs

    -- Partition of the environment into branching and leaf types.
    (branchingTypes, leafTypes) = partition isBranchingType env
    isBranchingType t = isRoot t || any rec (tcons t)
    -- Only ever called on elements of @env@, so the empty case is unreachable.
    isRoot t = case env of
      (root : _) -> t == root
      [] -> False

    -- Signature of the type owning a constructor, looked up in either half.
    branchingSigOf = tsig . conType branchingTypes
    leafSigOf = tsig . conType leafTypes

    branchingConNames = consList branchingTypes
    isBranchingCon cn = cn `elem` branchingConNames
    branchingProbs = filterKeys isBranchingCon allProbs
    branchingTermProbs = filterKeys isBranchingCon termProbs
    branchingSigs = typeSigs branchingTypes
    branchingCount = length branchingTypes

    gwMatrix = genGWMatrix branchingTypes branchingProbs
    -- Row vector with a 1 in the first (root) position and 0 elsewhere.
    rootRow = fromList 1 branchingCount (1 : repeat 0)

    -- Row for level @k@. Exponent 0 is special-cased, presumably because
    -- @(^)@ would go through @fromInteger 1@ rather than a properly sized
    -- identity matrix — confirm.
    levelRow :: Int -> Matrix Double
    levelRow k
      | k == 0 = rootRow
      | otherwise = rootRow * (gwMatrix ^ k)

    -- Pair the entries of a row with the branching type signatures.
    bySig row = Map.fromList (zip branchingSigs (toList row))

    -- Contribution of levels 0 .. size - 1.
    firstLevelsExp = Map.mapWithKey scaleProb branchingProbs
      where
        scaleProb cn prob
          -- Single branching type: closed form of the geometric series.
          | branchingCount == 1 && ratio /= 1 =
              prob * ((1 - ratio ^ size) / (1 - ratio))
          | otherwise = prob * summedLevels ! branchingSigOf cn
        ratio = getElem 1 1 gwMatrix
        -- 'foldr1' rather than 'sum': there is no correctly sized zero matrix
        -- to start from.
        summedLevels = bySig (foldr1 (+) [levelRow k | k <- [0 .. size - 1]])

    -- Contribution of the final level, weighted with terminal probabilities.
    lastLevelExp = Map.mapWithKey termExp branchingTermProbs
      where
        termExp tn tp =
          sum
            [ tp
                * allProbs ! cname con
                * fromIntegral (occurrences targetSig con)
                * prevLevel ! branchingSigOf (cname con)
            | con <- branchingCons
            ]
          where
            targetSig = branchingSigOf tn
        prevLevel = bySig (levelRow (size - 1))
        branchingCons = concatMap tcons branchingTypes

    branchingExp = Map.unionWith (+) firstLevelsExp lastLevelExp

    -- Extend the accumulated prediction with one leaf-type constructor.
    addLeafCon acc cn = Map.insert cn (leafExp acc cn) acc
    -- Constructors not yet present in the accumulator count as 0.
    leafExp acc cn =
      sum
        [ allProbs ! cn
            * Map.findWithDefault 0 (cname con) acc
            * fromIntegral (occurrences sig con)
        | con <- allCons
        ]
      where
        sig = leafSigOf cn

    allCons = concatMap tcons env
    leafTypeCons = concatMap tcons leafTypes
    -- Names of all constructors that take the type of @cn@ as an argument.
    generatorsOf cn = [cname con | con <- allCons, sig `elem` cargs con]
      where
        sig = leafSigOf cn

    -- Reverse topological order of the "is generated by" graph, so that the
    -- generators of a constructor are processed before the constructor itself.
    orderedLeafCons = reverse [conNameOf v | v <- topSort depGraph]
    conNameOf v = let (_, cn, _) = vertexInfo v in cn
    (depGraph, vertexInfo) =
      graphFromEdges'
        [((), cname con, generatorsOf (cname con)) | con <- leafTypeCons]
-- | Empirically measure constructor counts: draw 100000 values from the
-- generator resized to @size@, combine their 'count' maps by addition, divide
-- every total by the number of samples and print the averages via 'showMap'.
confirm :: (Countable a) => Size -> Gen a -> IO ()
confirm size arb = do
  values <- mapM (const drawOne) [1 .. sampleCount]
  let totals = Map.unionsWith (+) (map count values)
      averages = Map.map perSample totals
  putStrLn (showMap averages)
  where
    sampleCount = 100000 :: Int
    -- One independent draw from the generator at the requested size.
    drawOne = generate (resize size arb)
    perSample c = fromIntegral c / fromIntegral sampleCount