module Camfort.Transformation.DerivedTypeIntro where
import Data.Data
import Data.List hiding (union, insert)
import Data.Maybe
import Debug.Trace
import Text.Read (readMaybe)

import Control.Monad.State.Lazy
import Data.Generics.Uniplate.Operations
import qualified Data.Map as Data.Map
import Data.Set hiding (foldl, map)
import Language.Fortran

import Camfort.Analysis.Annotations
import Camfort.Analysis.IntermediateReps
import Camfort.Analysis.Syntax
import Camfort.Analysis.Types
import Camfort.Helpers
import Camfort.Transformation.Syntax
import Camfort.Traverse
-- | Run the derived-type introduction over every file. Each program unit of
-- each file goes through 'typeStructPerProgram'; the per-unit reports are
-- combined by the @(Report, _)@ monad (i.e. with Report's 'Monoid').
typeStruct :: [(Filename, Program Annotation)] -> (Report, [(Filename, Program Annotation)])
typeStruct = mapM refactorFile
  where
    refactorFile (f, ps) = (,) f <$> mapM typeStructPerProgram ps
-- | An edge-labelled graph, stored as a list of @((from, to), label)@ edges.
type Graph v a = [((v, v), a)]

-- | A single edge whose label is paired with an 'Int' weight
-- ('calculateWeights' stores the number of merged duplicate edges there).
type WeightedEdge v a = ((v, v), (a, Int))

-- | A graph of 'WeightedEdge's: a 'Graph' whose labels carry a weight.
type WeightedGraph v a = Graph v (a, Int)
-- | All edge endpoints in edge order (@x@ before @y@ for each edge).
-- Duplicates are kept, so a vertex appears once per edge it touches.
-- Works for both 'Graph' and 'WeightedGraph' since the label is ignored.
vertices :: [((v, v), a)] -> [v]
vertices edges = [ end | ((x, y), _) <- edges, end <- [x, y] ]
-- | Whether @v@ is an endpoint of at least one edge of the graph.
isVertex :: Eq v => v -> [((v, v), a)] -> Bool
isVertex v = elem v . vertices
-- | The label of the first edge that has @v@ as either endpoint, or
-- 'Nothing' when no edge touches @v@.
getVertex :: Eq v => v -> [((v, v), a)] -> Maybe a
getVertex v = fmap snd . find (\((v1, v2), _) -> v == v1 || v == v2)
-- | Introduce derived types in one program unit.
--
-- For each outermost 'Block' reached by 'descendBiM':
--
--   * build the interference graph of accesses that index the same array
--     within one expression ('toInterferenceGraph'), weight it
--     ('calculateWeights') and split it into sub-graphs
--     ('decomposeWeightedGraph');
--   * create a derived-type definition plus a variable of that type for each
--     sub-graph ('mkTypeDef', names numbered from 0) and append them to the
--     block's declarations;
--   * replace assignments to projection variables with null statements
--     ('elimProjectionDefs').
--
-- The block annotation is marked refactored only when at least one derived
-- type was introduced. The report is the 'show' of the weighted graph
-- followed by the 'show' of its sub-graphs.
typeStructPerProgram :: ProgUnit Annotation -> (Report, ProgUnit Annotation)
typeStructPerProgram p = descendBiM perBlock p
  where
    perBlock b@(Block a uses implicits sp decs blockBody) =
      let tenv = typeEnv b
          es = Exprs `topFrom` b
          prjVarsWTarget = map locsFromArrayIndex es
          iGraph = toInterferenceGraph prjVarsWTarget
          wiGraph = calculateWeights iGraph
          wgf = decomposeWeightedGraph wiGraph
          tDefsAndNames = evalState (mapM (mkTypeDef tenv (fst sp, fst sp)) wgf) 0
          -- Emptiness checked by pattern match rather than an O(n) 'length'.
          a' = case tDefsAndNames of
                 [] -> a
                 _  -> a { refactored = Just (fst sp) }
          blockBody' = elimProjectionDefs blockBody iGraph
          decs' = foldl (DSeq unitAnnotation) decs (map fst tDefsAndNames)
      in (show wiGraph ++ "\n\n" ++ show wgf, Block a' uses implicits sp decs' blockBody')
-- | Build the interference graph. Within each inner list of
-- @(array variable, access)@ pairs, every pair of entries (in list order)
-- that name the same array variable yields an edge between the two
-- accesses, labelled with that array variable. Edges are returned in the
-- reverse of the order in which they are found.
toInterferenceGraph :: [[(Variable, Access)]] -> Graph Access Variable
toInterferenceGraph pvars =
  reverse [ ((x, y), a) | ((a, x), (b, y)) <- concatMap listToSymmRelation pvars, a == b ]
-- | Every pair @(x, y)@ where @x@ occurs earlier in the list than @y@,
-- ordered by the position of @x@ and then of @y@. A list of @n@ elements
-- yields @n * (n - 1) / 2@ pairs.
listToSymmRelation :: [a] -> [(a, a)]
listToSymmRelation xs = [ (x, y) | (x : rest) <- tails xs, y <- rest ]
-- | Check a hand-written projection against the expected index ranges.
--
-- Every assignment @v = c@ in @stmt@ is recorded when @VarA v@ is an
-- endpoint of an edge in @graph@ and @c@ is an integer constant ('ConS').
-- The pair @(v, c)@ is stored under that edge's array label, and only the
-- first assignment to each @v@ counts. The result is 'True' iff, for every
-- recorded array, @ranges@ has an entry @(l, u)@ and the sorted constants
-- are exactly @[l .. u]@.
--
-- An array with no entry in @ranges@ cannot be confirmed, so the result is
-- 'False'. A non-constant right-hand side, or a constant that does not
-- parse as an 'Integer', is ignored rather than aborting the traversal.
correctManualImpl ranges stmt graph =
    Data.Map.foldrWithKey checkArray True pvarmap
  where
    pvarmap = execState (transformBiM collect stmt) Data.Map.empty

    checkArray arr vixs rest =
      case lookup arr ranges of
        Just (l, u) -> sort (map snd vixs) == [l .. u] && rest
        Nothing     -> False

    collect :: Fortran A -> State (Data.Map.Map Variable [(Variable, Integer)]) (Fortran A)
    collect a@(Assg _ _ e1 e2) = do
        modify (\indexMap -> fromMaybe indexMap (record indexMap))
        return a
      where
        -- 'Nothing' means "leave the map unchanged".
        record indexMap = do
          v   <- varExprToVariable e1
          arr <- getVertex (VarA v) graph
          val <- integerConstant e2
          case Data.Map.lookup arr indexMap of
            Nothing  -> Just (Data.Map.insert arr [(v, val)] indexMap)
            Just ixs -> case lookup v ixs of
              Just _  -> Nothing
              Nothing -> Just (Data.Map.adjust ((v, val) :) arr indexMap)
    collect f = return f

    integerConstant (ConS _ _ s) = readMaybe s
    integerConstant _            = Nothing
-- | Replace every assignment whose left-hand side is a variable @v@ with
-- @VarA v@ a vertex of @graph@ by a 'NullStmt' at the same span. The
-- statement's annotation is marked @refactored = Just (dropLine' sp)@.
-- All other statements are left untouched.
elimProjectionDefs :: Fortran A -> Graph Access Variable -> Fortran A
elimProjectionDefs stmt graph = transformBi dropDef stmt
  where
    dropDef s@(Assg p sp lhs _)
      | Just v <- varExprToVariable lhs
      , isVertex (VarA v) graph
      = NullStmt (p { refactored = Just $ dropLine' sp }) sp
      | otherwise = s
    dropDef s = s
-- | Not implemented. Judging by its name, it is presumably meant to rewrite
-- array accesses into projections of the introduced derived type (confirm
-- before implementing). Calling it raises an error that names this
-- function instead of an anonymous 'undefined'.
arrayAccessToProjection :: Fortran A -> Graph Access Variable -> Fortran A
arrayAccessToProjection _ _ =
  error "Camfort.Transformation.DerivedTypeIntro.arrayAccessToProjection: not implemented"
-- | Merge duplicate edges and attach to each surviving edge the number of
-- times it occurred. Two edges are duplicates when they are equal ignoring
-- annotations ('af'), in either orientation.
--
-- Only neighbouring edges of the sorted list are compared. Endpoints are
-- therefore first put in canonical @(min, max)@ order, so that @(x, y)@ and
-- @(y, x)@ sort next to each other even when another edge would otherwise
-- fall between them.
calculateWeights :: (Eq (AnnotationFree a), Eq (AnnotationFree v), Ord a, Ord v) => Graph v a -> WeightedGraph v a
calculateWeights xs = calcWs (sort (map orient xs)) 1
  where
    orient ((v1, v2), a) = ((min v1 v2, max v1 v2), a)

    -- n counts the run of duplicates ending at the current edge.
    calcWs [] _ = []
    calcWs [((v1, v2), a)] n = [((v1, v2), (a, n))]
    calcWs (e@((v1, v2), a) : rest@(e' : _)) n
      | af e == af e' || af e == af (swap e') = calcWs rest (n + 1)
      | otherwise = ((v1, v2), (a, n)) : calcWs rest 1

    swap ((a, b), v) = ((b, a), v)
-- | Collect @(variable name, location)@ pairs from the variable references
-- in @x@. A variable name whose index list is not made up entirely of
-- constants ('isConstant') is paired with each 'Locs' occurrence inside its
-- index expressions. Names indexed only by constants, or not indexed at
-- all, contribute nothing.
locsFromArrayIndex :: Data t => t -> [(Variable, Access)]
locsFromArrayIndex x = concat (concat (each (Vars `from` x) locsOfVar))
  where
    locsOfVar (Var _ _ ves) = each ves locsOfIndexed
    locsOfIndexed (VarName _ v, ixs)
      | all isConstant ixs = []
      | otherwise          = map ((,) v) (Locs `from` ixs)
-- | NOTE(review): unfinished stub. It ignores its first two arguments and
-- the contents of the list, and returns the 'vertices' function itself. It
-- fails with a pattern-match error on an empty list. Nothing in this part
-- of the module calls it; confirm whether it can be removed or needs
-- finishing.
findMatch v ix ((wg, n):wgns) = vertices
-- | A 'Decl' of the single variable @v@ with type @t@ at span @sp@. The
-- variable has an empty index list and is paired with a 'NullExpr' and
-- 'Nothing' (presumably "no initialiser" / no extra attribute; confirm
-- against the 'Decl' constructor). All annotations are 'unitAnnotation'.
mkTyDecl :: SrcSpan -> Variable -> Type Annotation -> Decl Annotation
mkTyDecl sp v t = Decl ann sp [(declared, NullExpr ann sp, Nothing)] t
  where
    ann = unitAnnotation
    declared = Var ann sp [(VarName ann v, [])]
-- | Build a derived-type definition for one weighted sub-graph, together
-- with a declaration of a variable of that type. The variable is named
-- after the array variable of the graph's first edge, with the invented
-- type name appended.
--
-- Each distinct projection variable (by name, in order of first
-- appearance among the edge endpoints) becomes one component. Its type is
-- the element type of its array, looked up in @tenv@. Components are
-- de-duplicated because a variable shared by several edges would otherwise
-- give duplicate component names.
--
-- Returns the combined declarations and the invented type name, and
-- advances the name counter ('inventName').
mkTypeDef :: TypeEnv Annotation -> SrcSpan -> WeightedGraph Access Variable -> State Int (Decl Annotation, String)
mkTypeDef tenv sp wg = do
    name <- inventName wg
    let typeName = SubName ra name
        typeDecl = DerivedTypeDef ra sp typeName [] [] (map fieldDecl fields)
        typeCons = BaseType ra (DerivedType ra typeName) [] (NullExpr ra sp) (NullExpr ra sp)
        valDecl  = Decl ra sp [(Var ra sp [(VarName ra (arrayVar ++ name), [])], NullExpr ra sp, Nothing)] typeCons
    return (DSeq unitAnnotation typeDecl valDecl, name)
  where
    ra = unitAnnotation { refactored = Just (fst sp) }

    -- (component name, array variable it projects from), first occurrence wins.
    fields = nubBy (\f g -> fst f == fst g)
               [ (accessToVarName acc, va) | ((vx, vy), (va, _)) <- wg, acc <- [vx, vy] ]

    fieldDecl (fieldName, va) =
      case lookup va tenv of
        Just t  -> mkTyDecl sp fieldName (arrayElementType t)
        Nothing -> error $ "Can't find the type of " ++ show va ++ "\n"

    arrayVar = case wg of
      ((_, (av, _)) : _) -> av
      []                 -> error "mkTypeDef: cannot introduce a derived type for an empty graph"
-- | Invent a fresh type name. For every character position across the
-- variable names of the graph's vertices (one entry per edge endpoint, so
-- repeated vertices count repeatedly), take the 'mode' of that column.
-- Then append the current counter value and increment the counter.
inventName :: WeightedGraph Access Variable -> State Int String
inventName graph = state (\counter -> (stem ++ show counter, counter + 1))
  where
    stem = map mode (transpose (map accessToVarName (vertices graph)))
-- | The strict-majority character of a string: the character that occurs
-- more than @length x `div` 2@ times. Returns @'X'@ when there is no such
-- character, including for the empty string.
mode :: String -> Char
mode [] = 'X'
mode x
  | count > length x `div` 2 = best
  | otherwise                = 'X'
  where
    freqs = [ (c, length run) | run@(c : _) <- group (sort x) ]
    -- A strict majority is unique, so tie-breaking never affects the result.
    (best, count) = maximumBy (\p q -> compare (snd p) (snd q)) freqs
-- | Split a weighted graph into sub-graphs. Edges are first grouped by the
-- array variable in their label. Within each group, edges that share a
-- vertex (directly or transitively) end up in the same sub-graph.
decomposeWeightedGraph :: forall v a . (Show v, Ord v, Ord a) => WeightedGraph v a -> [WeightedGraph v a]
decomposeWeightedGraph g =
    map snd (concatMap (Data.List.foldl' binEdge []) (groupBy sameArrayVar (sortBy compareArrayVar g)))
  where
    sameArrayVar (_, (av, _)) (_, (av', _)) = av == av'
    compareArrayVar (_, (av, _)) (_, (av', _)) = compare av av'

-- | Add one edge to a list of bins, each bin being a vertex set plus its
-- edges. The bins containing either endpoint are removed and merged,
-- together with the new edge, into a bin placed at the front. An endpoint
-- not yet in any bin contributes a fresh singleton set holding that
-- endpoint, so that later edges touching it join this bin.
binEdge :: (Show v, Ord v, Ord a) => [(Set v, WeightedGraph v a)] -> WeightedEdge v a -> [(Set v, WeightedGraph v a)]
binEdge bins e@((x, y), _) = (xVerts `union` yVerts, e : (xEdges ++ yEdges)) : remaining
  where
    ((xVerts, xEdges), afterX)    = takeBin x bins
    ((yVerts, yEdges), remaining) = takeBin y afterX

    -- Remove and return the bin containing v, or a fresh bin {v} with no edges.
    takeBin v [] = ((insert v empty, []), [])
    takeBin v (bin@(vs, _) : rest)
      | member v vs = (bin, rest)
      | otherwise   = let (found, rest') = takeBin v rest in (found, bin : rest')