{- CAO Compiler Copyright (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . -} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} {- Module : $Header$ Description : Indistinguishable functions. Copyright : (C) 2014 Cryptography and Information Security Group, HASLab - INESC TEC and Universidade do Minho License : GPL Maintainer : Paulo Silva Stability : experimental Portability : non-portable -} module Language.CAO.Transformation.Indist ( mkIndistFun , indist ) where import Control.Applicative import Data.List import qualified Data.Map as M import Data.Set ( Set ) import qualified Data.Set as Set import Data.Maybe ( catMaybes ) import qualified Data.Traversable as T import qualified Data.Foldable as F import Language.CAO.Common.Error import Language.CAO.Common.Fresh import Language.CAO.Common.Monad import Language.CAO.Common.Outputable import Language.CAO.Common.SrcLoc import Language.CAO.Common.State import Language.CAO.Common.Var import Language.CAO.Syntax import Language.CAO.Syntax.Utils ( getVars, getLVars, sameKind, fvs, defVar ) import Language.CAO.Analysis.CFG import Language.CAO.Analysis.SsaBack ( introduceDefs, rmVars ) -------------------------------------------------------------------------------- -- * Indistinguishable functions 
--------------------------------------------------------------------------------

-- | Apply countermeasures to two function definitions.
--
-- The two function names are converted with 'mkFunName' and looked up in
-- @cfgs@ with 'lookupDef'. When both are found, their CFGs are transformed by
-- 'mkIndistCfg' and put back at the positions the originals occupied, so the
-- relative order of all CFGs is preserved. Otherwise an 'IndistFail' warning
-- is emitted and @cfgs@ is returned unchanged.
mkIndistFun :: CaoMonad m => String -> String -> [CaoCFG] -> m [CaoCFG]
mkIndistFun (mkFunName -> fn1) (mkFunName -> fn2) cfgs
  | Just ((p1, p2), (cfg1, cfg2), cfgs2) <- mcfgs
  , valid cfg1
  , valid cfg2
  = do
      (cfg1', cfg2') <- mkIndistCfg (fn1, cfg1) (fn2, cfg2)
      return $ insertPos [(p1, cfg1'), (p2, cfg2')] cfgs2
  | otherwise = indistWarn fn1 fn2 >> return cfgs
  where
    mcfgs :: Maybe ((Int, Int), (CaoCFG, CaoCFG), [CaoCFG])
    mcfgs = do
      (p1, cfg1, cfgs') <- lookupDef fn1 cfgs
      (q2, cfg2, cfgs'') <- lookupDef fn2 cfgs'
      -- 'q2' is a position in cfgs', which no longer contains the element
      -- that was at 'p1'. 'insertPos' needs positions in the original list,
      -- so shift it back past the removed element.
      let p2 = if q2 >= p1 then q2 + 1 else q2
      return ((p1, p2), (cfg1, cfg2), cfgs'')
    -- TODO: stub
    valid _ = True

-- | Make the bodies of two function CFGs indistinguishable.
--
-- After 'removeSsaDecl', each CFG must consist of a single inner basic block
-- between 'entryNode' and 'exitNode' (see 'innerNode'). The two blocks are
-- rewritten with 'indist', the results go through 'rmVars' and
-- 'introduceDefs', and finally their declarations are equalised with
-- 'mkIndistDecls'. If either CFG does not have that shape, an 'IndistFail'
-- warning is emitted and both CFGs are returned unchanged.
mkIndistCfg :: CaoMonad m => (Name, CaoCFG) -> (Name, CaoCFG) -> m (CaoCFG, CaoCFG)
mkIndistCfg (name1, cfg1) (name2, cfg2)
  | Just ((n1, n2), (b1, b2), (c1, c2)) <- mcfgs = do
      (b1', b2') <- indist b1 b2
      let cfg1' = introduceDefs $ rmVars $ cfg1 { blocks = M.insert n1 (b1', c1) bcfg1 }
          cfg2' = introduceDefs $ rmVars $ cfg2 { blocks = M.insert n2 (b2', c2) bcfg2 }
      mkIndistDecls cfg1' cfg2'
  | otherwise = indistWarn name1 name2 >> return (cfg1, cfg2)
  where
    bcfg1 = blocks $ removeSsaDecl cfg1
    bcfg2 = blocks $ removeSsaDecl cfg2
    mcfgs = do
      (n1, b1, c1) <- innerNode entryNode [exitNode] bcfg1
      (n2, b2, c2) <- innerNode entryNode [exitNode] bcfg2
      return ((n1, n2), (b1, b2), (c1, c2))

-- | Equalise the variable declarations of two CFGs with 'indistDecls'.
--
-- Each CFG must have a single inner basic block between 'entryNode' and
-- 'exitNode' (see 'innerNode'). If either does not, both are returned
-- unchanged (silently).
mkIndistDecls :: CaoMonad m => CaoCFG -> CaoCFG -> m (CaoCFG, CaoCFG)
mkIndistDecls cfg1 cfg2
  | Just ((n1, n2), (b1, b2), (c1, c2)) <- mcfgs = do
      (b1', b2') <- indistDecls b1 b2
      return ( cfg1 { blocks = M.insert n1 (b1', c1) bcfg1 }
             , cfg2 { blocks = M.insert n2 (b2', c2) bcfg2 } )
  | otherwise = return (cfg1, cfg2)
  where
    bcfg1 = blocks cfg1
    bcfg2 = blocks cfg2
    mcfgs :: Maybe ((NodeId, NodeId), (BasicBlock, BasicBlock), (Connections, Connections))
    mcfgs = do
      (n1, b1, c1) <- innerNode entryNode [exitNode] bcfg1
      (n2, b2, c2) <- innerNode entryNode [exitNode] bcfg2
      return ((n1, n2), (b1, b2), (c1, c2))

-- | Give two basic blocks the same number of variable declarations.
--
-- In each block the declarations ('VDecl') are moved to the front, keeping
-- their relative order, followed by the remaining statements. The block with
-- fewer declarations is padded with 'dummyDecl' copies (fresh variables of
-- the same types) of the other block's surplus declarations.
--
-- Pre: all operations are already "indistinguishable".
indistDecls :: CaoMonad m => BasicBlock -> BasicBlock -> m (BasicBlock, BasicBlock)
indistDecls b1 b2 = do
    (db1', db2') <- padded
    return (db1' ++ rb1, db2' ++ rb2)
  where
    (db1, rb1) = partition isDecl b1
    (db2, rb2) = partition isDecl b2
    ldb1 = length db1
    ldb2 = length db2
    padded
      | ldb1 == ldb2 = return (db1, db2)
      | ldb1 > ldb2 = do
          db2'' <- mapM dummyDecl (drop ldb2 db1)
          return (db1, db2 ++ db2'')
      | otherwise = do -- ldb2 > ldb1
          db1'' <- mapM dummyDecl (drop ldb1 db2)
          return (db1 ++ db1'', db2)
    isDecl (L _ (VDecl _)) = True
    isDecl _ = False

-- | Create a dummy copy of a variable declaration in which every declared
-- variable is replaced by a fresh local variable of the same type.
--
-- Only 'VDecl' statements are supported; any other statement is an error.
dummyDecl :: CaoMonad m => LStmt Var -> m (LStmt Var)
dummyDecl (unLoc -> VDecl vd) = genLoc . VDecl <$> T.mapM (freshVar Local . varType) vd
dummyDecl s = error $ "Language.CAO.Transformation.Indist.dummyDecl: failed to "
                   ++ "create a dummy declaration of this kind! " ++ showPpr s

-- | Find the single inner node of a CFG fragment.
--
-- Succeeds when node @e@ is connected to exactly one node @n@ and the
-- connections of @n@ are exactly @next@. It then returns @n@ together with
-- its basic block and connections.
innerNode :: NodeId -> [NodeId] -> M.Map NodeId (BasicBlock, Connections)
          -> Maybe (NodeId, BasicBlock, Connections)
innerNode e next m
  | Just (_, [n]) <- M.lookup e m   -- entry
  , Just (b, rest) <- M.lookup n m  -- inner
  , rest == next                    -- connections are OK, TODO: ordering
  = Just (n, b, rest)
  | otherwise = Nothing

-- | Find the unique CFG whose definition defines exactly the name @n@.
--
-- Returns the CFG's position in the input list, the CFG itself, and the
-- remaining CFGs in their original order. Fails when there is no match or
-- more than one.
lookupDef :: Name -> [CaoCFG] -> Maybe (Int, CaoCFG, [CaoCFG])
lookupDef n cfgs
  | ([(i, cfg)], cfgs') <- partitionPos hasName cfgs = Just (i, cfg, cfgs')
  | otherwise = Nothing
  where
    hasName = (== [n]) . map varName . defVar . definition

-- | Like 'partition', but each selected element is tagged with its 0-based
-- position in the input list. Both result lists keep the input order.
partitionPos :: (a -> Bool) -> [a] -> ([(Int, a)], [a])
partitionPos f lst = (ys, map snd ns)
  where
    (ys, ns) = partition (f . snd) (zip [0 ..] lst)

-- | Insert each @(i, x)@ at position @i@, processing positions in ascending
-- order. When the positions are the ones the elements had in the original
-- list, this exactly undoes their removal. Positions past the end append.
insertPos :: [(Int, a)] -> [a] -> [a]
insertPos lst xs = foldl' (\b (i, x) -> insertAt i x b) xs $ sortBy compareFst lst
  where
    compareFst (i1, _) (i2, _) = compare i1 i2

-- | Insert an element at the given position; past the end it is appended.
insertAt :: Int -> a -> [a] -> [a]
insertAt 0 x lst = x : lst
insertAt _ x [] = [x]
insertAt n x (y:ys) = y : insertAt (n - 1) x ys

-- | Emit an 'IndistFail' warning for the two given function names.
indistWarn :: CaoMonad m => Name -> Name -> m ()
indistWarn v1 = caoWarning defSrcLoc . IndistFail v1

-- | Turn two CFG basic blocks into indistinguishable ones.
--
-- Notes: (b1', b2') <- b1 `indist` b2
indist :: CaoMonad m => BasicBlock -> BasicBlock -> m (BasicBlock, BasicBlock)
indist b1 b2 = mkIndist (mkStmtGraph b1) (mkStmtGraph b2)

-- | Algorithm for indistinguishable functions.
--
-- Runs the search in 'doMkSTree' from an empty schedule and returns the
-- statement lists of the cheapest complete solution found. It is an error if
-- the search produces no complete solution.
-- TODO: check best place for dummy ops
mkIndist :: CaoMonad m => StmtGraph -> StmtGraph -> m (BasicBlock, BasicBlock)
mkIndist g1 g2 = do
    tr <- doMkSTree [SN { cost = 0, stmt1 = [], stmt2 = [], rest1 = g1, rest2 = g2 }]
    case sortBy (\(c1, _, _) (c2, _, _) -> compare c1 c2) tr of
      (_, x, y) : _ -> return (x, y)
      [] -> error $ "Language.CAO.Transformation.Indist.mkIndist: no "
                 ++ "indistinguishable schedule was found for the two blocks"

--------------------------------------------------------------------------------
-- ** Solution
--------------------------------------------------------------------------------

-- | A node of the scheduling search.
data SNode = SN
  { cost  :: Int         -- accumulated cost of the dummy operations added
  , stmt1 :: BasicBlock  -- statements scheduled for block 1, most recent first
  , stmt2 :: BasicBlock  -- statements scheduled for block 2, most recent first
  , rest1 :: StmtGraph   -- statements of block 1 still to be scheduled
  , rest2 :: StmtGraph   -- statements of block 2 still to be scheduled
  }

-- | Estimated total cost of a node: its accumulated cost plus the weight of
-- both remaining graphs.
fCost :: SNode -> Int
fCost sn = cost sn + fDist (rest1 sn) + fDist (rest2 sn)

-- | Order search nodes by 'fCost'.
cmpNd :: SNode -> SNode -> Ordering
cmpNd sn1 sn2 = compare (fCost sn1) (fCost sn2)

{- Not used but can be useful in the future
nextNode :: SNode -> SNode -> SNode
nextNode (SN sc b1 b2 _ _) (SN sc2 s1 s2 g1' g2') =
  SN (sc + sc2) (s1 ++ b1) (s2 ++ b2) g1' g2'
-}

-- | Search for complete schedules.
--
-- If the first pending node has nothing left to schedule, its result (cost
-- and both statement lists in execution order) is emitted and the search
-- continues with the remaining nodes. Otherwise every pending node is
-- expanded with 'nextNodes', the successors are sorted by 'fCost', and only
-- the first 'beamWidth' of them are kept for the next step.
doMkSTree :: CaoMonad m => [SNode] -> m [(Int, BasicBlock, BasicBlock)]
doMkSTree [] = return []
doMkSTree es@(sn:xs)
  | nullG g1 && nullG g2 = do
      rs <- doMkSTree xs
      return $ (cost sn, reverse $ stmt1 sn, reverse $ stmt2 sn) : rs
  | otherwise = do
      alts <- sortBy cmpNd . concat <$> mapM nextNodes es
      doMkSTree (take beamWidth alts)
      --- $ concatMap (\e -> map (nextNode e) alts) es
  where
    g1 = rest1 sn
    g2 = rest2 sn
    -- Number of cheapest candidates kept after each expansion step.
    beamWidth :: Int
    beamWidth = 200

-- | Remaining weight of a statement graph.
fDist :: StmtGraph -> Int
fDist (SGraph w _) = w

-- | Successors of a search node.
--
-- Either a ready statement of each block is scheduled together, provided they
-- are of the same kind ('sameKind'), or a single ready non-return statement
-- of one block is scheduled alone. In the latter case an assignment is
-- matched by a dummy operation from 'mkDummyOp' on the other side; its cost is
-- added and its fresh variables are registered with 'storeTmpVar'.
nextNodes :: CaoMonad m => SNode -> m [SNode]
nextNodes sn = (sn' ++) <$> dummys
  where
    g1 = rest1 sn
    g2 = rest2 sn
    altsG1 = anyStmt g1
    altsG2 = anyStmt g2
    sn' = map mkAlt $ combinations altsG1 altsG2
    mkAlt ((s1, g1'), (s2, g2')) = sn { stmt1 = s1 : stmt1 sn
                                      , stmt2 = s2 : stmt2 sn
                                      , rest1 = g1'
                                      , rest2 = g2'
                                      }
    dummys = do
      d1 <- mapM addDL $ filter (not . isRet . fst) altsG1
      d2 <- mapM addDR $ filter (not . isRet . fst) altsG2
      return $ d1 ++ d2
    addDL (s, g)
      | not (needsDummy s) = return $ sn { stmt1 = s : stmt1 sn, rest1 = g }
      | otherwise = do
          (n, vs, s') <- mkDummyOp s
          F.mapM_ storeTmpVar vs
          return $ sn { cost = cost sn + n
                      , stmt1 = s : stmt1 sn
                      , stmt2 = s' : stmt2 sn
                      , rest1 = g
                      }
    addDR (s, g)
      | not (needsDummy s) = return $ sn { stmt2 = s : stmt2 sn, rest2 = g }
      | otherwise = do
          (n, vs, s') <- mkDummyOp s
          F.mapM_ storeTmpVar vs
          return $ sn { cost = cost sn + n
                      , stmt2 = s : stmt2 sn
                      , stmt1 = s' : stmt1 sn
                      , rest2 = g
                      }
    -- TODO: Refactor
    isRet (L _ (Ret _)) = True
    isRet _ = False
    needsDummy (L _ (Assign _ _)) = True
    needsDummy _ = False

-- | All pairs of candidates (one from each list) whose statements are of the
-- same kind according to 'sameKind'.
combinations :: [(LStmt Var, StmtGraph)] -> [(LStmt Var, StmtGraph)]
             -> [((LStmt Var, StmtGraph), (LStmt Var, StmtGraph))]
combinations l1 l2 =
  [ ((s1, g1), (s2, g2))
  | (s1, g1) <- l1
  , (s2, g2) <- l2
  , sameKind s1 s2
  ]

--------------------------------------------------------------------------------
-- ** Dependency graphs
--------------------------------------------------------------------------------

type LOC = Int
type Weight = Int

-- a := b;
-- b := c;
-- r := s;
-- z := b + r;
--
-- 1 -> (a := b, [])
-- 2 -> (b := c, [1])
-- 3 -> (r := s, [])
-- 4 -> (z := b + r, [2,3])

-- | Statement dependency graph.
--
-- Holds the total 'stmtCost' weight of the statements it contains, and a map
-- from each statement's position to the statement and the positions it
-- depends on.
data StmtGraph = SGraph Weight (M.Map LOC (LStmt Var, [LOC]))

instance PP StmtGraph where
  ppr (SGraph _ m) =
    vsep $ map (\(l, s) -> ppr l <+> text "->" <+> ppr s) $ M.assocs m

-- | Check if dependency graph is null.
nullG :: StmtGraph -> Bool
nullG (SGraph _ m) = M.null m

{- Not used but useful in the future.
-- | emptyGraph
emptyGraph :: StmtGraph
emptyGraph = SGraph 0 M.empty
-}

-- | Create a dependency graph from a basic block. Statements are numbered
-- from 1 in block order.
mkStmtGraph :: BasicBlock -> StmtGraph
mkStmtGraph ss = SGraph w $! lssDeps
  where
    lss = zip [1 ..] ss -- zip [length ss, length ss -1..1] $ ss
    (w, lssDeps) = calculateDeps M.empty M.empty lss

-- | Compute the total weight of a list of numbered statements and, for each
-- statement, the positions it depends on.
--
-- @lvars@ maps a variable to the last position whose 'getLVars' contained it.
-- @vars@ does the same for 'getVars'. A statement depends on the last
-- position that assigned any of its 'getVars' variables, and on the last
-- position whose 'getVars' contained any variable it assigns.
--
-- NOTE(review): only the most recent position is remembered per variable. If
-- 'getVars' returns only the variables that are read, then write-after-write
-- dependencies, and write-after-read dependencies on earlier readers, are not
-- recorded. That is presumably fine for SSA input (as produced for
-- 'mkIndistCfg'), but confirm before reusing this elsewhere.
calculateDeps :: M.Map Var LOC -> M.Map Var LOC -> [(LOC, LStmt Var)]
              -> (Weight, M.Map LOC (LStmt Var, [LOC]))
calculateDeps _ _ [] = (0, M.empty)
calculateDeps lvars vars ((loc, stmt):rest) =
    (w' + stmtCost stmt, mm `seq` M.insert loc (stmt, nub $ deps1 ++ deps2) mm)
  where
    lvs = getLVars stmt
    vs = getVars stmt
    nlvs = foldl' (\m v -> M.insert v loc m) lvars lvs
    nvs = foldl' (\m v -> M.insert v loc m) vars vs
    deps1 = catMaybes $ map (`M.lookup` lvars) vs
    deps2 = catMaybes $ map (`M.lookup` vars) lvs
    (w', mm) = calculateDeps nlvs nvs rest

{- Not used but useful in the future
takeBlock :: StmtGraph -> (BasicBlock, StmtGraph)
takeBlock (SGraph w a) = ng `seq` (stmts, SGraph w' ng)
  where
    noDeps = M.filter (null . snd) a
    stmts = map fst $ M.elems noDeps
    locs = M.keys noDeps
    (w',ng) = M.foldWithKey fAdjDeps (w,a) a
    fAdjDeps :: LOC -> (LStmt Var, [LOC]) -> (Weight, M.Map LOC (LStmt Var, [LOC]))
             -> (Weight, M.Map LOC (LStmt Var, [LOC]))
    fAdjDeps k (stmt, deps) (wgt, mp)
      | k `elem` locs = (wgt - stmtCost stmt, mp `seq` M.delete k mp)
      | otherwise = (wgt, mp `seq` M.insert k (stmt, deps \\ locs) mp)
-}

-- | All ways of scheduling one statement next.
--
-- Each statement with no pending dependencies is paired with the graph that
-- remains after it is removed. In that graph the statement's weight is
-- subtracted and its position is dropped from every dependency list.
anyStmt :: StmtGraph -> [(LStmt Var, StmtGraph)]
anyStmt (SGraph w a) = map fGetAlts ndlst
  where
    ndlst = M.assocs $ M.filter (null . snd) a
    fGetAlts :: (LOC, (LStmt Var, [LOC])) -> (LStmt Var, StmtGraph)
    fGetAlts (k, (s, _)) = (s, SGraph (w - stmtCost s) $! removeLoc k)
    -- Delete node @k@ and drop it from the dependency lists of the others.
    removeLoc :: LOC -> M.Map LOC (LStmt Var, [LOC])
    removeLoc k =
      M.map (\(stmt, deps) -> (stmt, filter (/= k) deps)) (M.delete k a)

{- Not used but useful in the future
-- | Traverse StmtGraph
toStmtList :: StmtGraph -> [LStmt Var]
toStmtList g
  | nullG g = []
  | otherwise = s' ++ toStmtList g'
  where
    (s', g') = takeBlock g

stmtsOf :: StmtGraph -> [LStmt Var]
stmtsOf (SGraph _ a) = map fst $ M.elems a
-}

--------------------------------------------------------------------------------
-- ** Operations
--------------------------------------------------------------------------------

---- | Compare two statement blocks.
----
---- The result is an integer whose value denotes the cost of introducing the
---- necessary dummy ops to turn both blocks indistinguishable
--compareBlocks :: BasicBlock -> BasicBlock -> Int
--compareBlocks = undefined

-- | Create a dummy op mirroring a statement.
--
-- For an assignment, the root variable of each left-hand side is replaced by
-- a fresh local variable of the same type, keeping struct field names and
-- container arguments. For assignments and procedure calls ('FCallS'), each
-- expression is rebuilt by 'mkDummyLExpr'.
--
-- Returns the summed expression cost, the set of fresh variables introduced,
-- and the new statement. Any other statement kind is an error.
mkDummyOp :: CaoMonad m => LStmt Var -> m (Int, Set Var, LStmt Var)
mkDummyOp (unLoc -> Assign lvs es) = do
    (vs', lvs') <- unzip <$> mapM mkDummyLv lvs
    (ns, vs'', es') <- unzip3 <$> mapM mkDummyLExpr es
    return (sum ns, Set.unions $ vs' ++ vs'', genLoc $ Assign lvs' es')
mkDummyOp (unLoc -> FCallS fn es) = do
    (ns, vs, es') <- unzip3 <$> mapM mkDummyLExpr es
    return (sum ns, Set.unions vs, genLoc $ FCallS fn es')
mkDummyOp s = error $ "Language.CAO.Transformation.Indist.mkDummyOp: failed to "
                   ++ "create a dummy operation of this kind! " ++ showPpr s
-- mkDummyOp (Ret es) = Ret <$> mapM mkDummyLExpr es
-- mkDummyOp (Ite i t me) =
-- mkDummyOp (Seq (SeqIter id) [LStmt id]
-- mkDummyOp (While e1 ss)
-- mkDummyOp (VDecl vd)

-- | Replace the root variable of an l-value by a fresh local variable of the
-- same type. Returns the singleton set of that fresh variable and the new
-- l-value.
mkDummyLv :: CaoMonad m => LVal Var -> m (Set Var, LVal Var)
mkDummyLv (LVVar (L _ v)) = lvvar <$> freshVar Local (varType v)
  where
    lvvar v' = (Set.singleton v', LVVar $ genLoc v')
mkDummyLv (LVStruct lv n) = fixT2 (flip LVStruct n) (mkDummyLv lv)
mkDummyLv (LVCont t lv p) = fixT2 (flip (LVCont t) p) (mkDummyLv lv)

-- | 'mkDummyExpr' lifted to located expressions (the location is kept).
mkDummyLExpr :: CaoMonad m => TLExpr Var -> m (Int, Set Var, TLExpr Var)
mkDummyLExpr (L l e) = fixT3 (L l) (mkDummyExpr e)

-- | Apply a function to the second component of a monadic pair.
fixT2 :: CaoMonad m => (a -> b) -> m (c, a) -> m (c, b)
fixT2 f m = (\(a, b) -> (a, f b)) <$> m

-- | Apply a function to the third component of a monadic triple.
fixT3 :: CaoMonad m => (a -> b) -> m (r, s, a) -> m (r, s, b)
fixT3 f m = (\(a, b, c) -> (a, b, f c)) <$> m

-- | Rebuild an expression with every traversed variable replaced by a fresh
-- local variable of the same type.
--
-- Returns the cost ('costAOp' for a top-level arithmetic binary operation,
-- 0 otherwise), the free variables of the new expression, and the expression.
-- TODO: complete with other exprs, fix cost of ops
mkDummyExpr :: CaoMonad m => TExpr Var -> m (Int, Set Var, TExpr Var)
mkDummyExpr (TyE t e@(BinaryOp (ArithOp op) _ _)) = do
    e' <- T.mapM (freshVar Local . varType) e
    return (costAOp op, fvs e', TyE t e')
mkDummyExpr e = do
    e' <- T.mapM (freshVar Local . varType) e
    return (0, fvs e', e') -- TODO: Complete!!!

{- Not used but useful in the future
-- | BasicBlock cost
blockCost :: BasicBlock -> Int
blockCost = sum . map stmtCost
-}

-- | Statement cost: the summed expression cost of an assignment's right-hand
-- sides or of a procedure call's arguments; 0 for any other statement.
stmtCost :: LStmt Var -> Int
stmtCost (unLoc -> Assign _ es) = sum $ map costLExpr es
stmtCost (unLoc -> FCallS _ es) = sum $ map costLExpr es
stmtCost _ = 0

-- | Cost of a located, typed expression.
costLExpr :: TLExpr Var -> Int
costLExpr (L _ (TyE _ e)) = costExpr e

-- | Cost of an expression: only a top-level arithmetic operator is counted.
-- TODO: complete with other exprs, fix cost of ops
costExpr :: Expr Var -> Int
costExpr (BinaryOp (ArithOp op) _ _) = costAOp op
costExpr _ = 0

-- | Relative cost of an arithmetic operator.
costAOp :: AOp -> Int
costAOp Plus = 1
costAOp Minus = 1
costAOp Times = 10
costAOp Div = 10
costAOp ModOp = 10
costAOp Power = 100

-- TODO: create dependency funcs. Place statements with no dependencies. Check
-- all possible reorderings with the cost of the necessary dummy instructions
-- and pick the lowest. Remove dependencies from graph and continue.