{- Copyright 2016, Dominic Orchard, Andrew Rice, Mistral Contrastin, Matthew Danish

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
-}

{- Units of measure extension to Fortran

Files: Units.hs
       UnitsEnvironment.hs

TODO:
 * Deal with variable shadowing in "contained" functions.
 * Better errors with line number info

-}

{-# LANGUAGE ScopedTypeVariables, ImplicitParams, DoAndIfThenElse #-}

module Extensions.Units where

import Data.Ratio
import Data.Maybe
import Data.Matrix
import Data.List
import Data.Char (isNumber)
import qualified Data.Vector as V
import Data.Label.Mono (Lens)
import qualified Data.Label
import Data.Label.Monadic hiding (modify)
import Data.Function
import Data.Data
import Data.Char
import Control.Monad.State.Strict hiding (gets)
import Data.Generics.Uniplate.Operations

import Helpers
import Output
import Analysis.Annotations
import Analysis.Syntax
import Analysis.Types
import Extensions.UnitsEnvironment -- Provides the types and data accessors used in this module
import Extensions.UnitsSolve -- Solvers for the Gaussian matrix
import Language.Fortran
import Language.Fortran.Pretty
import Transformation.Syntax

-- For debugging and development purposes
import qualified Debug.Trace as D

{- HELPERS -}

-- | Update a list-valued piece of state (selected by a label) by consing
-- the given element onto the front.
infix 2 <<
(<<) :: MonadState f m => Lens (->) f [o] -> o -> m ()
(<<) lens o = lens =. (o:)

-- | Update a list-valued piece of state (selected by a label) by appending
-- the given element at the end.  Same type and fixity as '<<'.
infix 2 <<++
(<<++) :: MonadState f m => Lens (->) f [o] -> o -> m ()
(<<++) lens o = lens =. (++ [o])

{- START HERE!
Two main functions of this file: inferUnits and removeUnits -}

-- | Strip every unit-of-measure annotation from each block of the program.
-- Never produces a report.
removeUnits :: (Filename, Program Annotation) -> (Report, (Filename, Program Annotation))
removeUnits (fname, x) =
  let ?criticals = False
  in ("", (fname, map (descendBi removeUnitsInBlock) x))

-- *************************************
--   Unit inference (top - level)
-- *************************************

-- | Run unit inference in "critical variables" mode.  No annotations are
-- inserted; the report instead lists the variables computed by
-- 'criticalVars' (or states that there are none).  The program is returned
-- unchanged.
inferCriticalVariables :: (?solver :: Solver, ?assumeLiterals :: AssumeLiterals)
                       => (Filename, Program Annotation)
                       -> (Report, (Filename, Program Annotation))
inferCriticalVariables (fname, x) =
  let ?criticals = True
      ?debug     = False
  in
    let infer = do
          _ <- doInferUnits x
          vars <- criticalVars
          case vars of
            [] -> report <<++ "No critical variables. Appropriate annotations."
            _  -> report <<++ "Critical variables: " ++ intercalate "," vars
          ifDebug debugGaussian
        (_, env) = runState infer emptyUnitEnv
        -- Prefix every report line with the file name
        messages = concat [fname ++ ": " ++ msg ++ "\n" | msg <- Data.Label.get report env]
    in (messages, (fname, x))

-- | Infer (and, on success, insert) unit annotations for a program.  The
-- report holds every accumulated message, each prefixed with the file name,
-- followed by the number of checked/inferred user variables.
inferUnits :: (?solver :: Solver, ?assumeLiterals :: AssumeLiterals)
           => (Filename, Program Annotation)
           -> (Report, (Filename, Program Annotation))
inferUnits (fname, x) =
  let ?criticals = False
      ?debug     = False
  in
    let (y, env) = runState (doInferUnits x) emptyUnitEnv
        counted = countVariables (_varColEnv env) (_debugInfo env) (_procedureEnv env)
                                 (fst $ _linearSystem env) (_unitVarCats env)
        messages = concat [fname ++ ": " ++ msg ++ "\n" | msg <- Data.Label.get report env]
                   ++ fname ++ ": checked/inferred " ++ show counted ++ " user variables\n"
    in (messages, (fname, y))

-- | Count the columns of the system that are Variable or Argument columns
-- bound to at least one named variable.  The debug-info and procedure
-- environment arguments are currently unused.
countVariables :: VarColEnv -> DebugInfo -> ProcedureEnv -> Matrix Rational -> [UnitVarCategory] -> Int
countVariables vars _debugs _procs matrix ucats = length $ filter isUserVariable [1 .. ncols matrix]
  where
    isUserVariable c = isUserCategory (ucats !! (c - 1)) && not (null (lookupVarsByCols vars [c]))
    isUserCategory Variable = True
    isUserCategory Argument = True
    isUserCategory _        = False

-- | Infer units for a whole program: per-program-unit inference, then
-- interprocedural constraints.  Annotations are inserted (and their count
-- reported) only when inference succeeded and we are not in
-- critical-variables mode; otherwise the program is returned as is.
doInferUnits :: (?criticals :: Bool, ?solver :: Solver, ?debug :: Bool, ?assumeLiterals :: AssumeLiterals)
             => Program Annotation -> State UnitEnv (Program Annotation)
doInferUnits x = do
  mapM_ inferProgUnits x
  ifDebug (report <<++ "Finished inferring prog units")
  ifDebug debugGaussian
  inferInterproceduralUnits x
  succeeded <- gets success
  let insertAnnotations = not ?criticals && succeeded
  p <- if insertAnnotations
         then mapM (descendBiM insertUnitsInBlock) x
         else return x -- don't insert unit annotations
  (n, added) <- gets evUnitsAdded
  when insertAnnotations $
    report <<++ ("Added " ++ show n ++ " non-unitless annotation: " ++ intercalate "," added)
  return p

-- | Infer units for a program unit and then, recursively, for its children,
-- so that the parent scope is always processed first.
inferProgUnits :: (?criticals :: Bool, ?solver :: Solver, ?debug :: Bool, ?assumeLiterals :: AssumeLiterals)
               => ProgUnit Annotation -> State UnitEnv ()
inferProgUnits p = do
    -- infer units for the *root* program unit
    inferPUnit p
    -- infer units for the *children* program units
    mapM_ inferProgUnits (children p :: [ProgUnit Annotation])
  where
    -- Infer the units for the *root* program unit (not children)
    inferPUnit :: ProgUnit Annotation -> State UnitEnv ()
    inferPUnit (Main _ _ _ _ b _) = inferBlockUnits b Nothing
    inferPUnit (Sub _ _ _ (SubName _ n) (Arg _ a _) b) =
      inferBlockUnits b (Just (n, Nothing, argNames a))
    inferPUnit (Function _ _ _ (SubName _ n) (Arg _ a _) r b) =
      inferBlockUnits b (Just (n, Just (resultName n r), argNames a))
    inferPUnit (Module _ _ (SubName _ n) _ _ d _) =
      transformBiM (inferDecl (Just (n, Nothing, []))) d >> return ()
    inferPUnit _ = return ()

    argNames :: ArgName a -> [Variable]
    argNames (ArgName _ n) = [n]
    argNames (ASeq _ n1 n2) = argNames n1 ++ argNames n2
    argNames (NullArg _) = []

    -- A function without an explicit RESULT variable returns through its name
    resultName :: Variable -> Maybe (VarName a) -> Variable
    resultName n Nothing = n
    resultName _ (Just (VarName _ r)) = r

-- | Infer units for one block.  For procedures (subroutines and functions)
-- the system is additionally solved and row-reduced straight away, so that
-- the procedure's parametric unit constraint is ready for call sites.
inferBlockUnits :: (?solver :: Solver, ?criticals :: Bool, ?debug :: Bool, ?assumeLiterals :: AssumeLiterals)
                => Block Annotation -> Maybe ProcedureNames -> State UnitEnv ()
inferBlockUnits x proc = do
    resetTemps
    enterDecls x proc
    addProcedure proc
    descendBiM handleStmt x
    case proc of
      Just _ -> do
        -- Intermediate solve for procedures (subroutines & functions)
        ifDebug (report <<++ "Pre doing row reduce")
        consistent <- solveSystemM ""
        success =: consistent
        linearSystem =. reduceRows 1
        ifDebug (report <<++ "Post doing reduce")
        ifDebug debugGaussian
        return ()
      Nothing -> return ()
  where
    handleStmt :: Fortran Annotation -> State UnitEnv (Fortran Annotation)
    handleStmt s = do
      inferStmtUnits s
      return s

{-| reduceRows is a core part of the polymorphic unit checking for
procedures.  It is essentially an "optimisation" of the Gaussian matrix (not
in the sense of performance) that eliminates rows in the system such that
there are as few variables as possible.  Within a function, assuming
everything is consistent, this should generate a linear constraint between
the parameters and the return value as a single row in the matrix.  That row
is then used by the interprocedural constraints to hook up call sites with
definitions in a parametrically polymorphic way (i.e. lambda abstraction is
polymorphic in its units, unlike, say, ML). -}
reduceRows :: Col -> LinearSystem -> LinearSystem
reduceRows m (matrix, vector)
  | m > ncols matrix = (matrix, vector)
  | otherwise =
      case nonZeroRows of
        -- Found two rows with non-zero coefficients in this column:
        -- eliminate the second one and cut the system down.
        (r1 : r2 : _) ->
          case elimRow (matrix, vector) (Just r1) m r2 of
            Ok (matrix', vector') -> reduceRows m (cutSystem r2 (matrix', vector'))
            Bad _ _ _ -> reduceRows (m + 1) (matrix, vector)
        -- Fewer than two such rows: move on to the next column.
        _ -> reduceRows (m + 1) (matrix, vector)
  where
    nonZeroRows = filter (\n -> matrix ! (n, m) /= 0) [1 .. nrows matrix]

-- | Record a procedure in the procedure environment: the column of its
-- result (a fresh Variable column if the result variable was not declared)
-- and the columns of its arguments.
addProcedure :: Maybe ProcedureNames -> State UnitEnv ()
addProcedure Nothing = return ()
addProcedure (Just (name, resultName, argNames)) = do
    uenv <- gets varColEnv
    resultVar <- case resultName of
      Just rname -> case lookupWithoutSrcSpan rname uenv of
        Just (uvar, _) -> return $ Just uvar
        Nothing -> do
          m <- addCol Variable
          return $ Just (VarCol m)
      Nothing -> return Nothing
    let argVars = fmap (lookupUnitByName uenv) argNames
    procedureEnv << (name, (resultVar, argVars))
  where
    -- Arguments without a declared column fall back to column 1
    lookupUnitByName uenv v = maybe (VarCol 1) fst $ lookupWithoutSrcSpan v uenv

-- ***************************************
--   Unit inference (main, over all AST)
-- ***************************************

-- | Process every declaration in a block (see 'inferDecl').
enterDecls :: (?assumeLiterals :: AssumeLiterals) => Block Annotation -> Maybe ProcedureNames -> State UnitEnv (Block Annotation)
enterDecls x proc = transformBiM (inferDecl proc) x

-- | Allocate the column (and, for arrays, one column per dimension) for a
-- declared variable, record its category and explicit units, and, if the
-- declaration has an initialiser, equate the variable with it.
processVar :: (?assumeLiterals :: AssumeLiterals)
           => [UnitConstant] -> Maybe ProcedureNames -> (Expr Annotation, Expr Annotation)
           -> Type Annotation -> State UnitEnv (Expr Annotation, Expr Annotation)
processVar units proc exps@(Var _ s names, e) typ = do
  let (VarName _ v, es) = head names
  system <- gets linearSystem
  -- extendConstraints below appends exactly this column
  let m = ncols (fst system) + 1
  unitVarCats <<++ unitVarCat v proc
  extendConstraints units
  ms <- case toArrayType typ es of
          ArrayT _ bounds _ _ _ _ -> mapM (const $ fmap VarCol $ addCol Variable) bounds
          _ -> return []
  varColEnv << (VarBinder (v, s), (VarCol m, ms))
  -- If the declaration has a null expression, do not create a unifying variable
  case e of
    NullExpr _ _ -> return ()
    _ -> do
      uv <- inferExprUnits e
      mustEqual False (VarCol m) uv
      return ()
  return exps

-- | A variable is an Argument if it is one of the current procedure's
-- arguments, and an ordinary Variable otherwise.
unitVarCat :: Variable -> Maybe ProcedureNames -> UnitVarCategory
unitVarCat v proc
  | Just (_, _, args) <- proc, v `elem` args = Argument
  | otherwise = Variable

{-| inferDecl - extract and record information from explicit unit declarations -}
inferDecl :: (?assumeLiterals :: AssumeLiterals) => Maybe ProcedureNames -> Decl Annotation -> State UnitEnv (Decl Annotation)
inferDecl proc decl@(Decl _ _ d typ) = do
    let BaseType _ _ attrs _ _ = arrayElementType typ
    units <- sequence $ concatMap extractUnit attrs
    mapM_ (\(e1, e2, _multiplier) -> processVar units proc (e1, e2) typ) d
    return decl
inferDecl _ x@(MeasureUnitDef _ _ d) = do
    mapM_ learnDerivedUnit d
    return x
  where
    learnDerivedUnit (name, spec) = do
      denv <- gets derivedUnitEnv
      when (isJust $ lookup name denv) $ error "Redeclared unit of measure"
      unit <- convertUnit spec
      -- converting the spec must not have defined the unit itself
      denv' <- gets derivedUnitEnv
      when (isJust $ lookup name denv') $ error "Recursive unit-of-measure definition"
      derivedUnitEnv << (name, unit)
inferDecl _ x = return x

-- | Append a fresh column for a newly declared variable.  When explicit
-- units were given, also append a row fixing that column to the last listed
-- unit constant.
-- NOTE(review): row n is recorded in tmpRowsAdded even when no row is added
-- (units == []); confirm this is what resetTemps expects.
extendConstraints :: [UnitConstant] -> State UnitEnv ()
extendConstraints units = do
  (matrix, vector) <- gets linearSystem
  let n = nrows matrix + 1
      m = ncols matrix + 1
  linearSystem =: case units of
    [] -> (extendTo 0 0 m matrix, vector)
    _  -> (setElem 1 (n, m) $ extendTo 0 n m matrix, vector ++ [last units])
  tmpColsAdded << m
  tmpRowsAdded << n

-- | Solve the whole system and, if it is consistent, iterate the
-- interprocedural constraints to a fixed point.
inferInterproceduralUnits :: (?solver :: Solver, ?criticals :: Bool, ?debug :: Bool, ?assumeLiterals :: AssumeLiterals)
                          => Program Annotation -> State UnitEnv ()
inferInterproceduralUnits x = do
  when ?criticals reorderVarCols
  consistent <- solveSystemM "inconsistent"
  when consistent $ do
    system <- gets linearSystem
    -- Under the Poly policy literals are never assumed unitless, so the
    -- "assume literal units" step is treated as already done.
    let dontAssumeLiterals = case ?assumeLiterals of
          Poly     -> True
          Unitless -> False
          Mixed    -> False
    _ <- inferInterproceduralUnits' x dontAssumeLiterals system
    return ()

-- | Repeatedly add call-site constraints and re-solve until the system stops
-- changing.  An inconsistent solve restores the previous system.  At the
-- fixed point, literals may be assumed unitless once (see
-- 'assumeLiteralUnits') before iterating again.
inferInterproceduralUnits' :: (?solver :: Solver, ?criticals :: Bool, ?debug :: Bool)
                           => Program Annotation -> Bool -> LinearSystem -> State UnitEnv (Program Annotation)
inferInterproceduralUnits' x haveAssumedLiterals system1 = do
    addInterproceduralConstraints x
    consistent <- solveSystemM "inconsistent"
    if not consistent
      then do
        linearSystem =: system1
        return x
      else do
        system2 <- gets linearSystem
        if system1 == system2
          then if ?criticals then nextStep else checkUnderdeterminedM >> nextStep
          else inferInterproceduralUnits' x haveAssumedLiterals system2
  where
    nextStep
      | haveAssumedLiterals = return x
      | otherwise = do
          consistent <- assumeLiteralUnits
          if not consistent
            then return x
            else do
              system3 <- gets linearSystem
              inferInterproceduralUnits' x True system3

-- | Relabel column references after columns x and n have been swapped:
-- references to x become n and references to n become x.
class UpdateColInfo t where
  updateColInfo :: Col -> Col -> t -> t

instance UpdateColInfo VarCol where
  updateColInfo x n (VarCol y)
    | y == x    = VarCol n
    | y == n    = VarCol x
    | otherwise = VarCol y

instance UpdateColInfo VarColEnv where
  updateColInfo _ _ [] = []
  updateColInfo x n ((v, (uv, uvs)) : ys) =
    (v, (updateColInfo x n uv, map (updateColInfo x n) uvs)) : updateColInfo x n ys

instance UpdateColInfo Procedure where
  updateColInfo x n (Nothing, ps) = (Nothing, map (updateColInfo x n) ps)
  updateColInfo x n (Just p, ps) = (Just $ updateColInfo x n p, map (updateColInfo x n) ps)

instance UpdateColInfo ProcedureEnv where
  updateColInfo x n = map (\(s, p) -> (s, updateColInfo x n p))

instance UpdateColInfo (Int, a) where
  updateColInfo x n (y, s)
    | y == x    = (n, s)
    | y == n    = (x, s)
    | otherwise = (y, s)

-- Previously this instance was the identity, so column lists such as
-- tmpColsAdded were never relabelled by swapCols.
instance UpdateColInfo Int where
  updateColInfo x n y
    | y == x    = n
    | y == n    = x
    | otherwise = y

-- | Swap the entries at 1-based positions x and n of a list.
swapUnitVarCats :: Int -> Int -> [a] -> [a]
swapUnitVarCats x n xs = swapUnitVarCats' x n xs xs 1

-- | Worker for 'swapUnitVarCats': walks the list with 1-based counter c,
-- reading the swapped entries from the untouched copy ys.
swapUnitVarCats' :: Int -> Int -> [a] -> [a] -> Int -> [a]
swapUnitVarCats' _ _ [] _ _ = []
swapUnitVarCats' x n (z:zs) ys c
  | c == x    = (ys !! (n - 1)) : swapUnitVarCats' x n zs ys (c + 1)
  | c == n    = (ys !! (x - 1)) : swapUnitVarCats' x n zs ys (c + 1)
  | otherwise = z : swapUnitVarCats' x n zs ys (c + 1)

-- | Swap columns x and n of the system together with every piece of state
-- that refers to column numbers.
swapCols :: Int -> Int -> State UnitEnv ()
swapCols x n = do
  varColEnv =. updateColInfo x n
  procedureEnv =. updateColInfo x n
  calls =. updateColInfo x n
  unitVarCats =. swapUnitVarCats x n
  linearSystem =. (\(m, v) -> (switchCols x n m, v))
  debugInfo =. map (updateColInfo x n)
  tmpColsAdded =. map (updateColInfo x n)

{-| reorderVarCols puts any variable columns to the end of the Gaussian matrix
(along with the associated information) -}
reorderVarCols :: State UnitEnv ()
reorderVarCols = do
    (matrix, _) <- gets linearSystem
    reorderVarCols' (ncols matrix) 1
  where
    -- Right-most column at or before 'end' that is not a Variable (0 if none)
    correctEnd :: Int -> State UnitEnv Int
    correctEnd 0 = return 0
    correctEnd end = do
      ucats <- gets unitVarCats
      case ucats !! (end - 1) of
        Variable -> correctEnd (end - 1)
        _ -> return end

    reorderVarCols' :: Int -> Int -> State UnitEnv ()
    reorderVarCols' end c | c >= end = return ()
    reorderVarCols' end c = do
      ucats <- gets unitVarCats
      case ucats !! (c - 1) of
        Variable -> do
          end' <- correctEnd end
          -- If every column after c is already a Variable we are done;
          -- swapping would move column c backwards (or with column 0).
          when (end' > c) $ do
            swapCols end' c
            reorderVarCols' (end' - 1) (c + 1)
        _ -> reorderVarCols' end (c + 1)

-- | Add "literal is unitless" rows (see 'assumeLiteralUnits'') for every
-- column and re-solve; the previous system is restored if this makes it
-- inconsistent.
assumeLiteralUnits :: (?solver :: Solver, ?debug :: Bool) => State UnitEnv Bool
assumeLiteralUnits = do
  system@(matrix, _) <- gets linearSystem
  mapM_ assumeLiteralUnits' [1 .. ncols matrix]
  consistent <- solveSystemM "underdetermined"
  unless consistent $ linearSystem =: system
  return consistent

-- | For a column of category @Literal True@: take the first row with a
-- non-zero coefficient in it.  If there is no such row, or its first
-- non-zero column is not this one, or it involves a non-literal column, add
-- a row (with unit constant @Unitful []@) fixing the column to be unitless.
assumeLiteralUnits' :: Int -> State UnitEnv ()
assumeLiteralUnits' m = do
  (matrix, _) <- gets linearSystem
  ucats <- gets unitVarCats
  let pivotRow = find (\r -> matrix ! (r, m) /= 0) [1 .. nrows matrix]
      leadCol = pivotRow >>= \r -> find (\c -> matrix ! (r, c) /= 0) [1 .. ncols matrix]
      nonLiteral r c = matrix ! (r, c) /= 0 && ucats !! (c - 1) /= Literal True
      nonLiteralCol = pivotRow >>= \r -> find (nonLiteral r) [1 .. ncols matrix]
  when (ucats !! (m - 1) == Literal True && (leadCol /= Just m || isJust nonLiteralCol)) $ do
    n' <- addRow
    modify $ liftUnitEnv $ setElem 1 (n', m)

-- | Connect every recorded call site with the definition of the procedure
-- it calls, by copying the callee's constraint row onto the actual-argument
-- (and result) columns.  The program argument is not used.
addInterproceduralConstraints :: (?debug :: Bool) => Program Annotation -> State UnitEnv ()
addInterproceduralConstraints _ = do
    cs <- gets calls
    mapM_ addCall cs
  where
    addCall (name, (result, args)) = do
      penv <- gets procedureEnv
      case lookup name penv of
        Just (r, as) ->
          let (r1, r2) = decodeResult result r
          in handleArgs (args ++ r1) (as ++ r2)
        Nothing -> return ()

    handleArgs actualVars dummyVars = do
      let actual = map (\(VarCol uv) -> uv) actualVars
          dummy = map (\(VarCol uv) -> uv) dummyVars
      mapM_ (handleArg $ zip dummy actual) dummy

    -- An earlier experimental variant (one equality row per dummy/actual
    -- pair) was deprecated and has been removed.
    -- TODO: this can be optimised
    handleArg dummyToActual dummy = do
      (matrix, vector) <- gets linearSystem
      ifDebug debugGaussian
      ifDebug (report <<++ ("hArg - " ++ show dummyToActual ++ "-" ++ show dummy))
      -- the first row with a non-zero coefficient for the dummy variable;
      -- when there is none there is nothing to copy
      case find (\n -> matrix ! (n, dummy) /= 0) [1 .. nrows matrix] of
        Nothing -> return ()
        Just n ->
          -- the first non-zero column on that row (exists: column 'dummy')
          case find (\c -> matrix ! (n, c) /= 0) [1 .. ncols matrix] of
            Nothing -> return ()
            Just m -> do
              ifDebug (report <<++ ("n = " ++ show n ++ ", m = " ++ show m))
              -- only a row led by the dummy variable is copied
              when (m == dummy) $ do
                -- columns with non-zero coefficients to the right of the focus,
                -- and the actual columns they are paired with by dummyToActual
                let ms = filter (\c -> matrix ! (n, c) /= 0) [m .. ncols matrix]
                    m's = mapMaybe (flip lookup dummyToActual) ms
                    pairs = zip ms m's
                ifDebug (report <<++ ("ms = " ++ show ms ++ ", m's' = " ++ show m's
                                      ++ ", their zip = " ++ show pairs
                                      ++ " dA = " ++ show dummyToActual))
                newRow <- addRow' $ vector !! (n - 1)
                mapM_ (handleArgPair matrix n newRow) dummyToActual

    -- Copy the row: coefficient of dummy column m goes to actual column m'
    handleArgPair matrix n newRow (m, m') =
      modify $ liftUnitEnv $ setElem (matrix ! (n, m)) (newRow, m')

    decodeResult (Just r1) (Just r2) = ([r1], [r2])
    decodeResult Nothing Nothing = ([], [])
    decodeResult (Just _) Nothing = error "Subroutine used as a function!"
    decodeResult Nothing (Just _) = error "Function used as a subroutine!"

-- | Allocate a fresh Literal column for a constant expression and record
-- its source position for error messages.  The category flag is False only
-- under the Mixed literal policy.
inferLiteral e = do
  uv@(VarCol uvn) <- anyUnits (Literal (?assumeLiterals /= Mixed))
  debugInfo << (uvn, (srcSpan e, pprint e))
  return uv

-- | How a binary operator combines the units of its operands.
data BinOpKind = AddOp | MulOp | DivOp | PowerOp | LogicOp | RelOp

binOpKind :: BinOp a -> BinOpKind
binOpKind (Plus _)   = AddOp
binOpKind (Minus _)  = AddOp
binOpKind (Mul _)    = MulOp
binOpKind (Div _)    = DivOp
binOpKind (Or _)     = LogicOp
binOpKind (And _)    = LogicOp
binOpKind (Concat _) = AddOp
binOpKind (Power _)  = PowerOp
binOpKind (RelEQ _)  = RelOp
binOpKind (RelNE _)  = RelOp
binOpKind (RelLT _)  = RelOp
binOpKind (RelLE _)  = RelOp
binOpKind (RelGT _)  = RelOp
binOpKind (RelGE _)  = RelOp

-- | Left-biased choice between two Maybes.
(<**>) :: Maybe a -> Maybe a -> Maybe a
Nothing <**> x = x
(Just x) <**> _ = Just x

-- | Case-insensitive association-list lookup (Fortran names are
-- case-insensitive).
lookupCaseInsensitive :: String -> [(String, a)] -> Maybe a
lookupCaseInsensitive x m = fmap snd $ find (\(k, _) -> map toUpper k == x') m
  where
    x' = map toUpper x

-- | Case-insensitive lookup of a variable ignoring its source span.
lookupWithoutSrcSpan :: Variable -> [(VarBinder, a)] -> Maybe a
lookupWithoutSrcSpan v env = snd `fmap` find f env
  where
    f (VarBinder (w, _), _) = map toUpper w == v'
    v' = map toUpper v

-- | Case-insensitive lookup of a variable that must also match the span.
lookupWithSrcSpan :: Variable -> SrcSpan -> [(VarBinder, a)] -> Maybe a
lookupWithSrcSpan v s env = snd `fmap` find f env
  where
    f (VarBinder (w, t), _) = map toUpper w == v' && s == t
    v' = map toUpper v

-- | Return the column holding the units of an expression, adding whatever
-- constraints the expression implies.
inferExprUnits :: (?assumeLiterals :: AssumeLiterals) => Expr Annotation -> State UnitEnv VarCol
inferExprUnits e@(Con _ _ _) = inferLiteral e
inferExprUnits e@(ConL _ _ _ _) = inferLiteral e
inferExprUnits e@(ConS _ _ _) = inferLiteral e
inferExprUnits ve@(Var _ _ names) = do
    uenv <- gets varColEnv
    penv <- gets procedureEnv
    let (VarName _ v, args) = head names
    case lookupWithoutSrcSpan v uenv of
      -- array variable?
      Just (uv, uvs@(_:_)) -> inferArgUnits' uvs >> return uv
      -- function call?
      Nothing | not (null args) -> do
        case lookup (map toUpper v) intrinsicsDict of
          Just fun -> fun v
          Nothing -> return () -- unknown intrinsic: treated as an external call
        uv@(VarCol uvn) <- anyUnits Temporary
        debugInfo << (uvn, (srcSpan ve, pprint ve))
        uvs <- inferArgUnits
        let uvs' = justArgUnits args uvs
        calls << (v, (Just uv, uvs'))
        return uv
      -- scalar variable or external function call?
      Just (uv, []) -> inferArgUnits >> return uv
      -- default specifier
      _ | v == "*" -> inferLiteral ve
      -- otherwise it must be a known procedure (or just bad code)
      _ -> case lookupCaseInsensitive v penv of
        Just (Just uv, argUnits)
          | null args -> inferArgUnits' argUnits >> return uv
          | otherwise -> do
              uv' <- anyUnits Temporary
              uvs <- inferArgUnits
              let uvs' = justArgUnits args uvs
              calls << (v, (Just uv', uvs'))
              return uv'
        Just (Nothing, _) ->
          error $ "\n" ++ (showSrcFile . srcSpan $ ve) ++ ": subroutine " ++ v
                  ++ " used as a function at " ++ (showSrcSpan . srcSpan $ ve)
        Nothing ->
          error $ "\n" ++ (showSrcFile . srcSpan $ ve) ++ ": undefined variable " ++ v
                  ++ " at " ++ (showSrcSpan . srcSpan $ ve)
  where
    inferArgUnits = sequence [mapM inferExprUnits exprs | (_, exprs) <- names, not (nullExpr exprs)]
    inferArgUnits' uvs = sequence [ inferExprUnits expr >>= (\uv' -> mustEqual True uv' uv)
                                  | ((_, exprs), uv) <- zip names uvs
                                  , expr <- exprs
                                  , not (nullExpr [expr]) ]
    nullExpr [] = False
    nullExpr [NullExpr _ _] = True
    nullExpr ((NullExpr _ _):xs) = nullExpr xs
    nullExpr _ = False
    justArgUnits [NullExpr _ _] _ = [] -- zero-argument function call
    justArgUnits _ uvs = head uvs
inferExprUnits e@(Bin _ _ op e1 e2) = do
  uv1 <- inferExprUnits e1
  uv2 <- inferExprUnits e2
  (VarCol n) <- case binOpKind op of
    AddOp -> mustEqual True uv1 uv2
    MulOp -> mustAddUp uv1 uv2 1 1
    DivOp -> mustAddUp uv1 uv2 1 (-1)
    PowerOp -> powerUnits uv1 e2
    LogicOp -> mustEqual True uv1 uv2
    RelOp -> do
      mustEqual True uv1 uv2
      return $ VarCol 1
  debugInfo << (n, (srcSpan e, pprint e))
  return (VarCol n)
inferExprUnits (Unary _ _ _ e) = inferExprUnits e
inferExprUnits (CallExpr _ _ e1 (ArgList _ e2)) = do
  uv <- anyUnits Temporary
  inferExprUnits e1
  inferExprUnits e2
  error "CallExpr not implemented"
  return uv
-- inferExprUnits (NullExpr ...) shouldn't occur very often, as it adds unnecessary cruft
inferExprUnits (NullExpr _ _) = anyUnits Temporary
inferExprUnits (Null _ _) = return $ VarCol 1
inferExprUnits (ESeq _ _ e1 e2) = do
  inferExprUnits e1
  inferExprUnits e2
  return $ error "ESeq units wanted"
inferExprUnits (Bound _ _ e1 e2) = do
  uv1 <- inferExprUnits e1
  uv2 <- inferExprUnits e2
  mustEqual False uv1 uv2
inferExprUnits (Sqrt _ _ e) = do
  uv <- inferExprUnits e
  sqrtUnits uv
inferExprUnits (ArrayCon _ _ (e:exprs)) = do
  uv <- inferExprUnits e
  mapM_ (\e' -> do { uv' <- inferExprUnits e'; mustEqual True uv uv' }) exprs
  return uv
inferExprUnits (AssgExpr _ _ _ e) = inferExprUnits e

-- | Units of each element of a comma sequence (ESeq) of expressions.
inferExprSeqUnits :: (?assumeLiterals :: AssumeLiterals) => Expr Annotation -> State UnitEnv [VarCol]
inferExprSeqUnits (ESeq _ _ e1 e2) = liftM2 (++) (inferExprSeqUnits e1) (inferExprSeqUnits e2)
inferExprSeqUnits e = (:[]) `liftM` inferExprUnits e

-- | Infer units for an expression for its effect only; returns it unchanged.
handleExpr :: (?assumeLiterals :: AssumeLiterals) => Expr Annotation -> State UnitEnv (Expr Annotation)
handleExpr x = do
  inferExprUnits x
  return x

-- | A loop variable must have the same units as its start, end and step.
inferForHeaderUnits :: (?assumeLiterals :: AssumeLiterals)
                    => (Variable, Expr Annotation, Expr Annotation, Expr Annotation) -> State UnitEnv ()
inferForHeaderUnits (v, e1, e2, e3) = do
  uenv <- gets varColEnv
  case lookupWithoutSrcSpan v uenv of
    Just (uv, []) -> do
      uv1 <- inferExprUnits e1
      mustEqual True uv uv1
      uv2 <- inferExprUnits e2
      mustEqual True uv uv2
      uv3 <- inferExprUnits e3
      mustEqual True uv uv3
      return ()
    Just (_, _) ->
      report <<++ "Ill-formed Fortran code. Loop variable '" ++ v ++ "' is an array."
    Nothing ->
      report <<++ "Ill-formed Fortran code. Variable '" ++ v ++ "' is not declared."
inferSpecUnits :: (?assumeLiterals :: AssumeLiterals) => [Spec Annotation] -> State UnitEnv ()
inferSpecUnits = mapM_ $ descendBiM handleExpr

{-| inferStmtUnits, does what it says on the tin: walks a statement,
inferring units for every expression in it and equating the two sides of
assignments. -}
inferStmtUnits :: (?assumeLiterals :: AssumeLiterals) => Fortran Annotation -> State UnitEnv ()
inferStmtUnits (Assg _ _ e1 e2) = do
  uv1 <- inferExprUnits e1
  uv2 <- inferExprUnits e2
  mustEqual False uv1 uv2
  return ()
inferStmtUnits (DoWhile _ _ _ f) = inferStmtUnits f
inferStmtUnits (For _ _ _ (NullExpr _ _) _ _ s) = inferStmtUnits s
inferStmtUnits (For _ _ (VarName _ v) e1 e2 e3 s) = do
  inferForHeaderUnits (v, e1, e2, e3)
  inferStmtUnits s
inferStmtUnits (FSeq _ _ s1 s2) = mapM_ inferStmtUnits [s1, s2]
inferStmtUnits (If _ _ e1 s1 elseifs ms2) = do
  inferExprUnits e1
  inferStmtUnits s1
  sequence_ [inferExprUnits e >> inferStmtUnits s | (e, s) <- elseifs]
  maybe (return ()) inferStmtUnits ms2
inferStmtUnits (Allocate _ _ e1 e2) = mapM_ inferExprUnits [e1, e2]
inferStmtUnits (Backspace _ _ specs) = inferSpecUnits specs
inferStmtUnits (Call _ _ (Var _ _ [(VarName _ v, [])]) (ArgList _ e2)) = do
  uvs <- case e2 of
    NullExpr _ _ -> return []
    _ -> inferExprSeqUnits e2
  calls << (v, (Nothing, uvs))
inferStmtUnits (Call _ _ e1 (ArgList _ e2)) = mapM_ inferExprUnits [e1, e2]
inferStmtUnits (Open _ _ specs) = inferSpecUnits specs
inferStmtUnits (Close _ _ specs) = inferSpecUnits specs
inferStmtUnits (Continue _ _) = return ()
inferStmtUnits (Cycle _ _ _) = return ()
inferStmtUnits (Deallocate _ _ exprs e) = do
  mapM_ inferExprUnits exprs
  inferExprUnits e
  return ()
inferStmtUnits (Endfile _ _ specs) = inferSpecUnits specs
inferStmtUnits (Exit _ _ _) = return ()
inferStmtUnits (Forall _ _ (header, e) s) = do
  mapM_ inferForHeaderUnits header
  inferExprUnits e
  inferStmtUnits s
inferStmtUnits (Goto _ _ _) = return ()
inferStmtUnits (Nullify _ _ exprs) = mapM_ inferExprUnits exprs
inferStmtUnits (Inquire _ _ specs exprs) = do
  inferSpecUnits specs
  mapM_ inferExprUnits exprs
inferStmtUnits (Rewind _ _ specs) = inferSpecUnits specs
inferStmtUnits (Stop _ _ e) = do
  inferExprUnits e
  return ()
inferStmtUnits (Where _ _ e s s') = do
  inferExprUnits e
  inferStmtUnits s
  maybe (return ()) inferStmtUnits s'
inferStmtUnits (Write _ _ specs exprs) = do
  inferSpecUnits specs
  mapM_ inferExprUnits exprs
inferStmtUnits (PointerAssg _ _ e1 e2) = do
  uv1 <- inferExprUnits e1
  uv2 <- inferExprUnits e2
  mustEqual False uv1 uv2
  return ()
inferStmtUnits (Return _ _ e) = do
  inferExprUnits e
  return ()
inferStmtUnits (Label _ _ _ s) = inferStmtUnits s
inferStmtUnits (Print _ _ e exprs) = mapM_ inferExprUnits (e:exprs)
inferStmtUnits (ReadS _ _ specs exprs) = do
  inferSpecUnits specs
  mapM_ inferExprUnits exprs
inferStmtUnits (TextStmt _ _ _) = return ()
inferStmtUnits (NullStmt _ _) = return ()

-- *************************************
--   Matrix operations
-- *************************************

-- | Inverse of a 1-based permutation: position k holds the (1-based) index
-- at which k occurs in the input; values that do not occur are skipped.
inverse :: [Int] -> [Int]
inverse perm = [j + 1 | Just j <- map (flip elemIndex perm) [1 .. length perm]]

-- | Apply f repeatedly until the value stops changing, returning it.
fixValue :: Eq a => (a -> a) -> a -> a
fixValue f x = snd $ until (uncurry (==)) (\(_, y) -> (y, f y)) (x, f x)

-- The indexing for switchScaleElems and moveElem is 1-based, in line with Data.Matrix.

-- | @moveElem i j xs@ moves the element at position @min i j@ to position
-- @max i j@, shifting the elements in between one place towards the front;
-- this is the list analogue of 'moveCol'.  If the first position is below 1
-- or past the end, the list is returned unchanged; a second position past
-- the end moves the element to the end.  (The previous implementation
-- ignored @j@ and always moved the element to the end.)
moveElem :: Int -> Int -> [a] -> [a]
moveElem i j xs
  | i > j = moveElem j i xs
  | i < 1 = xs
  | otherwise =
      case splitAt (i - 1) xs of
        (before, z : after) ->
          let (between, rest) = splitAt (j - i) after
          in before ++ between ++ z : rest
        _ -> xs

-- | Add value to the matrix entry at pos.
incrElem :: Num a => a -> (Int, Int) -> Matrix a -> Matrix a
incrElem value pos matrix = setElem (matrix ! pos + value) pos matrix

-- | Move column @min i j@ to position @max i j@, shifting the columns in
-- between one place to the left.
moveCol :: Int -> Int -> Matrix a -> Matrix a
moveCol i j m
  | i > j = moveCol j i m
  | otherwise = matrix (nrows m) (ncols m) $ \(r, c) ->
      if c < i || c > j then m ! (r, c)
      else if c >= i && c < j then m ! (r, c + 1)
      else m ! (r, i)

-- | Append an all-zero column of the given category; returns its index.
addCol :: UnitVarCategory -> State UnitEnv Int
addCol category = do
  (matrix, vector) <- gets linearSystem
  let m = ncols matrix + 1
  linearSystem =: (extendTo 0 0 m matrix, vector)
  unitVarCats <<++ category
  tmpColsAdded << m
  return m

-- | Append an all-zero row whose unit constant is unitless; returns its index.
addRow :: State UnitEnv Int
addRow = addRow' (Unitful [])

-- | Append an all-zero row with the given unit constant; returns its index.
addRow' :: UnitConstant -> State UnitEnv Int
addRow' uc = do
  (matrix, vector) <- gets linearSystem
  let n = nrows matrix + 1
  linearSystem =: (extendTo 0 n 0 matrix, vector ++ [uc])
  tmpRowsAdded << n
  return n

-- | Lift a matrix transformation to the whole unit environment.
liftUnitEnv :: (Matrix Rational -> Matrix Rational) -> UnitEnv -> UnitEnv
liftUnitEnv f = Data.Label.modify linearSystem $ \(matrix, vector) -> (f matrix, vector)

-- *************************************
--   Unit inferences (Helpers)
-- *************************************

-- | mustEqual - says that two units must be the same, by adding the row
-- @uv2 - uv1 = 1@.  Returns one of the variables (the choice doesn't
-- matter; the left one is chosen).  With the flag set under the Mixed
-- literal policy, literal columns involved are marked @Literal True@.
mustEqual :: (?assumeLiterals :: AssumeLiterals) => Bool -> VarCol -> VarCol -> State UnitEnv VarCol
mustEqual flagAsUnitlessIfLit (VarCol uv1) (VarCol uv2) = do
    n <- addRow
    modify $ liftUnitEnv $ incrElem (-1) (n, uv1) . incrElem 1 (n, uv2)
    when (flagAsUnitlessIfLit && isMixed ?assumeLiterals) $
      unitVarCats =. zipWith markLiteral [1 ..]
    return $ VarCol uv1
  where
    isMixed Mixed = True
    isMixed _ = False
    markLiteral c (Literal _) | c == uv1 || c == uv2 = Literal True
    markLiteral _ cat = cat

-- | mustAddUp - used for multiplying and dividing.  Creates a new
-- 'temporary' column m with @k1*uv1 + k2*uv2 - m = 1@ and returns it.
mustAddUp :: VarCol -> VarCol -> Rational -> Rational -> State UnitEnv VarCol
mustAddUp (VarCol uv1) (VarCol uv2) k1 k2 = do
  m <- addCol Temporary
  n <- addRow
  modify $ liftUnitEnv $ incrElem (-1) (n, m) . incrElem k1 (n, uv1) . incrElem k2 (n, uv2)
  return $ VarCol m

-- | Units of @base ** exponent@.  A numeric-literal exponent p yields a new
-- temporary column equal to @p * base@ (p may be fractional, as for
-- 'sqrtUnits'); otherwise base and exponent are both forced to be unitless.
-- TODO: error handling in powerUnits
powerUnits :: (?assumeLiterals :: AssumeLiterals) => VarCol -> Expr Annotation -> State UnitEnv VarCol
powerUnits (VarCol uv) (Con _ _ powerString) =
  case parsePowerLiteral powerString of
    Just power -> do
      m <- addCol Temporary
      n <- addRow
      modify $ liftUnitEnv $ incrElem (-1) (n, m) . incrElem power (n, uv)
      return $ VarCol m
    Nothing -> mustEqual False (VarCol uv) (VarCol 1)
powerUnits uv e = do
  mustEqual False uv (VarCol 1)
  uv' <- inferExprUnits e
  mustEqual False uv' (VarCol 1)

-- | Parse the text of a numeric literal into an exact rational: an optional
-- sign, digits with an optional fractional part (at least one digit in
-- total), an optional exponent introduced by e/E/d/D/q/Q, and an optional
-- Fortran kind suffix (@_kind@), e.g. "2", "2.0", ".5", "2.5d0", "1_8".
-- Anything else gives Nothing.
parsePowerLiteral :: String -> Maybe Rational
parsePowerLiteral literal =
    case dropWhile isSpace (takeWhile (/= '_') literal) of
      '-' : rest -> fmap negate (unsigned rest)
      '+' : rest -> unsigned rest
      rest       -> unsigned rest
  where
    unsigned s =
      let (whole, afterWhole) = span isDigit s
          (fracDigits, afterFrac) = case afterWhole of
            '.' : t -> span isDigit t
            t       -> ("", t)
      in if null whole && null fracDigits
           then Nothing
           else do
             expo <- exponentPart (dropWhileEnd isSpace afterFrac)
             return ((integerValue whole + fractionValue fracDigits) * 10 ^^ expo)
    integerValue ds = if null ds then 0 else fromInteger (read ds)
    fractionValue ds = if null ds then 0 else read ds % (10 ^ length ds)
    exponentPart "" = Just (0 :: Integer)
    exponentPart (c : s)
      | c `elem` "eEdDqQ" = case s of
          '-' : ds -> fmap negate (digitsOnly ds)
          '+' : ds -> digitsOnly ds
          ds       -> digitsOnly ds
    exponentPart _ = Nothing
    digitsOnly ds
      | not (null ds) && all isDigit ds = Just (read ds)
      | otherwise = Nothing

-- | Units of a square root: a new temporary column equal to @0.5 * uv@.
sqrtUnits :: VarCol -> State UnitEnv VarCol
sqrtUnits (VarCol uv) = do
  m <- addCol Temporary
  n <- addRow
  modify $ liftUnitEnv $ incrElem (-1) (n, m) . incrElem 0.5 (n, uv)
  return $ VarCol m

-- | A fresh, unconstrained column of the given category.
anyUnits :: UnitVarCategory -> State UnitEnv VarCol
anyUnits category = do
  m <- addCol category
  return $ VarCol m

-- *************************************
--   Gaussian Elimination (Main)
-- *************************************

{-| Print debug information for non-zero coefficients from the Gaussian matrix -}
debugInfoForNonZeros :: [Rational] -> State UnitEnv String
debugInfoForNonZeros row = do
    debugs <- gets debugInfo
    let cSpots = concatMap (getInfo debugs) (zip [1..] row)
    return $ if null cSpots then "" else " arising from \n" ++ cSpots
  where
    getInfo _ (_, 0) = ""
    getInfo debugs (n, _) = case lookup n debugs of
      Just (sp, s) -> "\t" ++ showSrcSpan sp ++ " - " ++ s ++ "\n"
      _ -> ""

{- | An attempt at getting some useful user information.
Needs position information.

   Builds a "Conflict ..." message for a single row of the linear system,
   given that row's unit constant ('unit') and coefficients ('rowCoeffs'):

     * no non-zero coefficient and a unit with more than one factor:
       "Conflict since L != R", where L is the first factor and R the
       remaining factors with their powers negated (this also sets
       'success' to False);
     * no non-zero coefficient and exactly one factor: "Conflict since u != 1",
       with that factor's power negated;
     * no non-zero coefficient otherwise: dumps the Gaussian matrix and
       returns a generic apology;
     * otherwise: names the variables of the non-zero columns (prefixed by
       their coefficient when it is not 1), joined by "*", plus the unit.

   Source locations are appended via 'debugInfoForNonZeros'.  The 'row'
   argument is not used.  NOTE(review): ?num is bound to 0 here, presumably
   for 'makeUnitSpec' / 'pprint' -- confirm. -}
errorMessage :: (?debug :: Bool) => Row -> UnitConstant -> [Rational] -> State UnitEnv String
errorMessage row unit rowCoeffs =
  let ?num = 0 in do
    uvarEnv <- gets varColEnv
    debugs <- gets debugInfo
    u <- makeUnitSpec unit
    let unitStr = pprint u
    -- 1-based indices of the columns with non-zero coefficients
    let varCols = map (+1) (findIndices (\n -> n /= 0) rowCoeffs)
    if varCols == []
      then case unit of
        Unitful xs | length xs > 1 -> do
          let xs' = map (\(v, r) -> (v, r * (-1))) (tail xs)
          uR <- makeUnitSpec (Unitful $ xs')
          uL <- makeUnitSpec (Unitful [head xs])
          success =: False
          conflictInfo <- debugInfoForNonZeros rowCoeffs
          return $ let unitStrL = pprint uL
                       unitStrR = pprint uR
                       msg = "Conflict since " ++ unitStrL ++ " != " ++ unitStrR
                   in msg ++ conflictInfo
        {- A single unit with no variable column suggests an attempt to unify a unit with unitless -}
        Unitful xs | length xs == 1 -> do
          let xs' = map (\(v, r) -> (v, r * (-1))) xs
          uL <- makeUnitSpec (Unitful xs')
          let unitStrL = pprint uL
          ifDebug debugGaussian
          conflictInfo <- debugInfoForNonZeros rowCoeffs
          return $ "Conflict since " ++ unitStrL ++ " != 1" ++ conflictInfo
        _ -> do
          debugGaussian
          return "Sorry, I can't give a better error."
      else
        let varColsAndNames = zip varCols (lookupVarsByCols uvarEnv varCols)
            -- render each variable, prefixed by its coefficient unless it is 1
            exprStr' = map (\(k,v) -> if (rowCoeffs !! (k - 1)) == 1
                                        then v
                                        else (showRational (rowCoeffs !! (k - 1))) ++ "*" ++ v)
                           varColsAndNames
            exprStr = concat $ intersperse "*" exprStr'
            msg = "Conflict arising from " ++ exprStr ++ " of unit " ++ unitStr
        in do
          conflictInfo <- debugInfoForNonZeros rowCoeffs
          return $ msg ++ conflictInfo

-- | Explain an inconsistency found in rows 'ns' of the system (m, v).
-- For every non-zero column of those rows, gather the debug info of related
-- indices plus the variable binders of the column, sort by source line, and
-- report each distinct entry whose text is not purely numeric (digits, '.'
-- or '-').  The vector 'v' is not used.
reportInconsistency :: (?debug :: Bool) => LinearSystem -> [Int] -> State UnitEnv ()
reportInconsistency (m, v) ns = do
  uvarEnv <- gets varColEnv
  debugs <- gets debugInfo
  -- helper functions
  let srcLineCompare = compare `on` (srcLine . fst . fst)
  -- 1-based indices of the non-zero entries of a row/column vector
  let nonZeroVectorIndices = V.toList . V.map (+1) . V.findIndices (/= 0)
  -- examine all row numbers given to us as the parameter
  vs <- fmap (sortBy srcLineCompare . concat) . forM ns $ \ n -> do
    -- find out column indices of interest in the row
    let colsOfInterest = nonZeroVectorIndices (getRow n m)
    -- for each index of interest in the row, see what other rows also use it
    vs <- forM colsOfInterest $ \ i -> do
      let rowsOfInterest = nub . (i:) . nonZeroVectorIndices $ getCol i m
      -- lookup debug info for those row indices of interest
      -- NOTE(review): debugInfo entries are keyed by column numbers (they are
      -- added as (uvn, ...) with uvn a VarCol index), but here they are looked
      -- up by row numbers (plus column i) -- possible row/column mix-up; verify.
      let colDebugs = mapMaybe (flip lookup debugs) $ rowsOfInterest
      -- also lookup VarBinder info for i and convert it to same format
      let vs = map (\ (VarBinder (v, s)) -> (s, v)) $ lookupVarBindersByCols uvarEnv [i]
      return $ vs ++ colDebugs
    -- flatten it out
    return (concat vs)
  report <<++ "Caused by at least one of the following terms:"
  forM_ (nub vs) $ \ ((s1, _), str) -> do
    -- skip entries that are just numeric literals
    unless (all (\ x -> isNumber x || x == '.' || x == '-') str) $
      report <<++ "line " ++ show (srcLine s1) ++ ": " ++ str

-- | Solve the current system with 'solveSystemH_Either'.  On success the
-- solved system replaces the current one and True is returned.  On failure
-- "<adjective> units of measure" is reported, the offending rows are
-- explained via 'reportInconsistency', the system is left unchanged, and
-- False is returned.  (?solver is not referenced directly in this body.)
solveSystemM :: (?solver :: Solver, ?debug :: Bool) => String -> State UnitEnv Bool
solveSystemM adjective = do
  system <- gets linearSystem
  ifDebug debugGaussian
  case (solveSystemH_Either system) of
    Right system' -> do
      linearSystem =: system'
      ifDebug (report <<++ "After solve")
      ifDebug (debugGaussian)
      return True
    Left ns -> do
      report <<++ (adjective ++ " units of measure")
      reportInconsistency system ns
      return False
      -- linearSystem =: system'
      -- if (adjective `elem` ["inconsistent", "underdetermined"]) then
      --   do msg <- errorMessage row unit vars
      --      report <<++ msg
      --      return False
      -- else
      --   return False

-- | Find the underdetermined columns ('checkUnderdetermined'); if there are
-- any, dump the Gaussian matrix (unconditionally, not via ifDebug) and report
-- the expressions that could use annotations.  The result is always stored
-- in 'underdeterminedCols'.  ('uenv' is read but not used.)
checkUnderdeterminedM :: State UnitEnv ()
checkUnderdeterminedM = do
  ucats <- gets unitVarCats
  system <- gets linearSystem
  varenv <- gets varColEnv
  debugs <- gets debugInfo
  procenv <- gets procedureEnv
  let badCols = checkUnderdetermined ucats system
  uenv <- gets varColEnv
  if not (null badCols)
    then do
      let exprs = map (showExprLines ucats varenv procenv debugs) badCols
      let exprsL = concat $ intersperse "\n\t" exprs
      debugGaussian
      report <<++ "Underdetermined units of measure. Try adding units to: \n\t" ++ exprsL
      return ()
    else return ()
  underdeterminedCols =: badCols

-- | Starting from 'checkUnderdetermined'' on row 1, apply
-- 'propagateUnderdetermined' until the column list stops changing.
checkUnderdetermined :: [UnitVarCategory] -> LinearSystem -> [Int]
checkUnderdetermined ucats system@(matrix, vector) =
  fixValue (propagateUnderdetermined matrix) $ checkUnderdetermined' ucats system 1

-- | The critical variables of the current system, as display strings
-- (computed by 'criticalVars'' from row 1; cv2 is always empty).
criticalVars :: State UnitEnv [String]
criticalVars = do
  uvarenv <- gets varColEnv
  (matrix, _) <- gets linearSystem
  ucats <- gets unitVarCats
  dbgs <- gets debugInfo
  -- debugGaussian
  let cv1 = criticalVars' uvarenv ucats matrix 1 dbgs
  let cv2 = [] -- criticalVars
  return (cv1 ++ cv2)

-- | Walk rows i, i+1, ... of the matrix.  'm' (= firstNonZeroCoeff matrix
-- ucats) presumably gives each row's pivot column -- confirm.  Columns
-- skipped between the pivots of consecutive rows, and after the pivot of the
-- last row, are turned into strings by 'lookupVarsByColsFilterByArg'.
-- NOTE(review): the only base case is i == nrows matrix, so a call with
-- i > nrows matrix (e.g. an empty matrix with i = 1) does not terminate
-- normally.
criticalVars' :: VarColEnv -> [UnitVarCategory] -> Matrix Rational -> Row -> DebugInfo -> [String]
criticalVars' varenv ucats matrix i dbgs =
  let m = firstNonZeroCoeff matrix ucats
  in if (i == nrows matrix)
       then if (m i) /= (ncols matrix)
              then lookupVarsByColsFilterByArg matrix varenv ucats [((m i) + 1)..(ncols matrix)] dbgs
              else []
       else if (m (i + 1)) /= ((m i) + 1)
              then (lookupVarsByColsFilterByArg matrix varenv ucats [((m i) + 1)..(m (i + 1) - 1)] dbgs)
                   ++ (criticalVars' varenv ucats matrix (i + 1) dbgs)
              else criticalVars' varenv ucats matrix (i + 1) dbgs

-- | Turns each column index in 'cols' into a display string via 'lookupEnv'.
-- For a column bound to no variable, a Temporary column with some non-zero
-- coefficient is shown as "[expr: <span>@<info>]" when debug info exists.
-- The remainder of this definition lies beyond this section.
lookupVarsByColsFilterByArg :: Matrix Rational -> VarColEnv -> [UnitVarCategory] -> [Int] -> DebugInfo -> [String] lookupVarsByColsFilterByArg matrix uenv ucats cols dbgs = mapMaybe (\j -> lookupEnv j uenv) cols where lookupEnv j [] = --Nothing if (ucats !! (j - 1) == Temporary && (not (all (==0) (V.toList (getCol j matrix))))) then case (lookup j dbgs) of Just (srcSpan, info) -> Just ("[expr: " ++ (showSrcSpan srcSpan) ++ "@" ++ info ++ "]") Nothing -> Nothing else Nothing lookupEnv j ((VarBinder (v, _), (VarCol i, _)):uenv) | i == j = if (j <= length ucats) then case (ucats !!
(j - 1)) of Argument -> Nothing _ -> if (all (==0) (V.toList (getCol j matrix))) then Nothing else Just v else Nothing | otherwise = lookupEnv j uenv firstNonZeroCoeff :: Matrix Rational -> [UnitVarCategory] -> Row -> Col firstNonZeroCoeff matrix ucats row = case (V.findIndex (/= 0) (getRow row matrix)) of Nothing -> ncols matrix Just i -> i + 1 {- firstNonZeroCoeff' (V.toList $ getRow row matrix) 0 where {- -} firstNonZeroCoeff' [] n = n + 1 firstNonZeroCoeff' (0:rs) n = firstNonZeroCoeff' rs (n+1) firstNonZeroCoeff' (r:rs) n = case (ucats !! n) of Literal -> firstNonZeroCoeff' rs (n + 1) _ -> n + 1-} -- debug string ("n = " ++ show n ++ " vc = " ++ (show (vector !! (n - 1))) ++ " ms = " ++ show ms ++ " rest = " ++ show rest) `D.trace` checkUnderdetermined' :: [UnitVarCategory] -> LinearSystem -> Int -> [Int] checkUnderdetermined' ucats system@(matrix, vector) n | n > nrows matrix = [] | not ((drop 1 ms) == []) && vector !! (n - 1) /= Unitful [] = ms ++ rest | otherwise = rest where ms = filter significant [2 .. ncols matrix] significant m = matrix ! (n, m) /= 0 && ucats !! (m - 1) `notElem` [Literal False, Literal True, Argument, Temporary] rest = checkUnderdetermined' ucats system (n + 1) propagateUnderdetermined :: Matrix Rational -> [Int] -> [Int] propagateUnderdetermined matrix list = nub $ do m <- list n <- filter (\n -> matrix ! (n, m) /= 0) [1 .. nrows matrix] filter (\m -> matrix ! (n, m) /= 0) [1 .. ncols matrix] -- ************************************* -- Intrinsic functions: information & -- setup functions for them. 
-- -- ************************************* intrinsicsDict :: (?assumeLiterals :: AssumeLiterals) => [(String, String -> State UnitEnv ())] intrinsicsDict = map (\x -> (x, addPlain1ArgIntrinsic)) ["ABS", "ACHAR", "ADJUSTL", "ADJUSTR", "AIMAG", "AINT", "ANINT", "CEILING", "CONJG", "DBLE", "EPSILON", "FLOOR","FLOAT", "FRACTION", "HUGE", "IACHAR", "ICHAR", "INT", "IPARITY", "LOGICAL", "MAXEXPONENT", "MINEXPONENT", "NEW_LINE", "NINT", "NORM2", "NOT", "NULL", "PARITY", "REAL", "RRSPACING", "SPACING", "SUM", "TINY", "TRANSPOSE", "TRIM"] ++ map (\x -> (x, addPlain2ArgIntrinsic)) ["ALL", "ANY", "IALL", "IANY", "CHAR", "CMPLX", "DCOMPLX", "DIM", "HYPOT", "IAND", "IEOR", "IOR", "MAX", "MIN", "MAXVAL", "MINVAL","MODULO", "MOD"] ++ map (\x -> (x, addPlain1Arg1ExtraIntrinsic)) ["CSHIFT", "EOSHIFT", "IBCLR", "IBSET", "NEAREST", "PACK", "REPEAT", "RESHAPE", "SHIFTA", "SHIFTL", "SHIFTR", "SIGN"] ++ map (\x -> (x, addPlain2Arg1ExtraIntrinsic)) ["DSHIFTL", "DSHIFTR", "ISHFT", "ISHFTC", "MERGE", "MERGE_BITS"] ++ map (\x -> (x, addProductIntrinsic)) ["DOT_PRODUCT", "DPROD", "MATMUL"] ++ map (\x -> (x, addPowerIntrinsic)) ["SCALE", "SET_EXPONENT"] ++ map (\x -> (x, addUnitlessIntrinsic)) ["ACOS", "ACOSH", "ASIN", "ASINH", "ATAN", "ATANH", "BESSEL_J0", "BESSEL_J1", "BESSEL_Y0", "BESSEL_Y1", "COS", "COSH", "ERF", "ERFC", "ERFC_SCALED", "EXP", "EXPONENT", "GAMMA", "LOG", "ALOG", "LOG10", "LOG_GAMMA", "PRODUCT", "SIN", "SINH", "TAN", "TANH"] ++ map (\x -> (x, addUnitlessSubIntrinsic)) ["CPU_TIME", "RANDOM_NUMBER"] ++ map (\x -> (x, addUnitlessResult0ArgIntrinsic)) ["COMMAND_ARGUMENT_COUNT", "COMPILER_OPTIONS", "COMPILER_VERSION"] ++ map (\x -> (x, addUnitlessResult1ArgIntrinsic)) ["ALLOCATED", "ASSOCIATED", "BIT_SIZE", "COUNT", "DIGITS", "IS_IOSTAT_END", "IS_IOSTAT_EOR", "KIND", "LBOUND", "LCOBOUND", "LEADZ", "LEN", "LEN_TRIM", "MASKL", "MASKR", "MAXLOC", "MINLOC", "POPCOUNT", "POPPAR", "PRECISION", "PRESENT", "RADIX", "RANGE", "SELECTED_CHAR_KIND", "SELECTED_INT_KIND", 
"SELECTED_REAL_KIND", "SHAPE", "SIZE", "STORAGE_SIZE", "TRAILZ", "UBOUND", "UCOBOUND"] ++ map (\x -> (x, addUnitlessResult2SameArgIntrinsic)) ["ATAN2", "BGE", "BGT", "BLE", "BLT", "INDEX", "LGE", "LGT", "LLE", "LLT", "SCAN", "VERIFY"] ++ map (\x -> (x, addUnitlessResult2AnyArgIntrinsic)) ["BTEST", "EXTENDS_TYPE_OF", "SAME_TYPE_AS"] -- missing: ATOMIC_DEFINE, ATOMIC_REF, BESSEL_JN, BESSEL_YN, C_*, DATE_AND_TIME, EXECUTE_COMMAND_LINE, GET_COMMAND, GET_COMMAND_ARGUMENT, GET_ENVIRONMENT_VARIABLE, IBITS, any of the image stuff, MOVE_ALLOC, MVBITS, RANDOM_SEED, SPREAD, SYSTEM_CLOCK, TRANSFER, UNPACK {- [A] Various helpers for adding information about procedures to the type system -} addPlain1ArgIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv () addPlain1ArgIntrinsic name = do result <- anyUnits Variable arg <- anyUnits Argument mustEqual False result arg procedureEnv << (name, (Just result, [arg])) addPlain2ArgIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv () addPlain2ArgIntrinsic name = do result <- anyUnits Variable arg1 <- anyUnits Argument arg2 <- anyUnits Argument mustEqual False result arg1 mustEqual False result arg2 procedureEnv << (name, (Just result, [arg1, arg2])) addPlain1Arg1ExtraIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv () addPlain1Arg1ExtraIntrinsic name = do result <- anyUnits Variable arg1 <- anyUnits Argument arg2 <- anyUnits Argument mustEqual False result arg1 procedureEnv << (name, (Just result, [arg1, arg2])) addPlain2Arg1ExtraIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv () addPlain2Arg1ExtraIntrinsic name = do result <- anyUnits Variable arg1 <- anyUnits Argument arg2 <- anyUnits Argument arg3 <- anyUnits Argument mustEqual False result arg1 mustEqual False result arg2 procedureEnv << (name, (Just result, [arg1, arg2, arg3])) addProductIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv () 
-- Equations for 'addProductIntrinsic' (signature declared just before):
-- allocate the result column, then the two argument columns in order, and
-- tie the result to the combination produced by 'mustAddUp'.
addProductIntrinsic name = do
  res <- anyUnits Variable
  (lhs, rhs) <- (,) <$> anyUnits Argument <*> anyUnits Argument
  combined <- mustAddUp lhs rhs 1 1
  _ <- mustEqual False res combined
  procedureEnv << (name, (Just res, [lhs, rhs]))

-- | Power-like intrinsic (SCALE, SET_EXPONENT): the result shares the
-- first argument's unit and the second argument must be unitless.
addPowerIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv ()
addPowerIntrinsic name = do
  res <- anyUnits Variable
  (base, expo) <- (,) <$> anyUnits Argument <*> anyUnits Argument
  _ <- mustEqual False res base
  _ <- mustEqual False expo (VarCol 1)
  procedureEnv << (name, (Just res, [base, expo]))

-- | One-argument intrinsic whose argument and result are both unitless
-- (trigonometric, exponential, logarithmic, ...).
addUnitlessIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv ()
addUnitlessIntrinsic name = do
  res <- anyUnits Variable
  arg <- anyUnits Argument
  mapM_ (\col -> mustEqual False col (VarCol 1)) [res, arg]
  procedureEnv << (name, (Just res, [arg]))

-- | Subroutine-style intrinsic (no result) with one unitless argument.
-- NOTE(review): the argument column is allocated as 'Variable' here while
-- every other helper uses 'Argument' for arguments; confirm intended.
addUnitlessSubIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv ()
addUnitlessSubIntrinsic name =
  anyUnits Variable >>= \arg ->
    mustEqual False arg (VarCol 1) >> procedureEnv << (name, (Nothing, [arg]))

-- | Zero-argument intrinsic with a unitless result.
addUnitlessResult0ArgIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv ()
addUnitlessResult0ArgIntrinsic name = do
  res <- anyUnits Variable
  _ <- mustEqual False res (VarCol 1)
  procedureEnv << (name, (Just res, []))

-- | One-argument intrinsic with a unitless result; the argument is left
-- unconstrained.
addUnitlessResult1ArgIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv ()
addUnitlessResult1ArgIntrinsic name = do
  (res, arg) <- (,) <$> anyUnits Variable <*> anyUnits Argument
  _ <- mustEqual False res (VarCol 1)
  procedureEnv << (name, (Just res, [arg]))

-- | Two-argument intrinsic with a unitless result; the arguments are left
-- unconstrained and independent of each other.
addUnitlessResult2AnyArgIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv ()
addUnitlessResult2AnyArgIntrinsic name = do
  res <- anyUnits Variable
  args <- mapM anyUnits [Argument, Argument]
  _ <- mustEqual False res (VarCol 1)
  procedureEnv << (name, (Just res, args))

-- | Two-argument intrinsic with a unitless result whose two arguments must
-- share a unit (e.g. ATAN2, the string comparisons).
addUnitlessResult2SameArgIntrinsic :: (?assumeLiterals :: AssumeLiterals) => String -> State UnitEnv ()
-- Equations for 'addUnitlessResult2SameArgIntrinsic' (signature declared
-- just before): unitless result, and the two arguments share a unit.
addUnitlessResult2SameArgIntrinsic name = do
  result <- anyUnits Variable
  arg1 <- anyUnits Argument
  arg2 <- anyUnits Argument
  mustEqual False result (VarCol 1)
  mustEqual False arg1 arg2
  procedureEnv << (name, (Just result, [arg1, arg2]))

-- *************************************
--   Debugging and testing functions
--
-- *************************************

-- QuickCheck instance for matrices, used for testing matrix operations
{-
instance (Arbitrary a) => Arbitrary (Matrix a) where
    arbitrary = sized (\n -> do
                  xs <- vectorOf (n*n) arbitrary
                  return $ matrix n n (\(i, j) -> xs !! ((i-1)*n + (j-1))))
-}

-- | Matrix for development: 4x4 diagonal with entries 1..4.
fooMatrix :: Matrix Rational
fooMatrix = matrix 4 4 $ \(i, j) -> if i == j then toInteger i % 1 else 0

{-| debugGaussian - a debugging routine which shows the Gaussian matrix
    with various pieces of info, mainly used for development purposes.
    The rendered grid is appended to 'report' and also returned. -}
debugGaussian :: State UnitEnv String
debugGaussian = do
  grid' <- debugGaussian'
  report <<++ ("Dump of units-of-measure system matrix\n" ++ grid')
  return grid'

-- | Render the linear system as a padded text grid. The grid has column
-- headings, then one line per row (row number, coefficients, right-hand
-- side), then lines giving each column's category, description and
-- argument-variable name.
debugGaussian' :: State UnitEnv String
debugGaussian' = do
  ucats <- gets unitVarCats
  (matrix, rowv) <- gets linearSystem
  varenv <- gets varColEnv
  debugs <- gets debugInfo
  procenv <- gets procedureEnv
  let cols = [1 .. ncols matrix]
      grid =
        -- Column headings and then a space
        ["" : map show cols, []]
        -- Gaussian matrix
        ++ map (\r -> show r : map showRational (V.toList (getRow r matrix)) ++ [show (rowv !! (r - 1))]) [1 .. nrows matrix]
        -- Column categories
        ++ [[], "" : map showCat ucats]
        -- Debug info, e.g., expression or variable
        ++ ["" : map (showExpr ucats varenv procenv debugs) cols]
        -- Additional debug info for args that are also variables
        ++ ["" : map (showArgVars ucats varenv) cols]
      -- The cell width ignores the last cell of every line.
      colSize = maximum' (map (maximum' . notLast . map length) grid)
      expand cell = cell ++ replicate (colSize - length cell) ' '
      showLine cells = concatMap expand cells ++ "\n"
  return (concatMap showLine grid)
  where
    maximum' [] = 0
    maximum' xs = maximum xs
    notLast xs = take (length xs - 1) xs

-- | Short description of column @c@ for the matrix dump.
--
-- Variables show their name (or "=proc" for a procedure result column).
-- Temporaries and Literals show their recorded expression text, and
-- Arguments show "proc#index". Unknown columns show "?" (Temporaries ""),
-- and Literals without debug info show the column number instead of
-- aborting the dump.
showExpr :: [UnitVarCategory] -> VarColEnv -> ProcedureEnv -> DebugInfo -> Int -> String
showExpr cats vars procs debugs c =
  case cats !! (c - 1) of
    Variable -> case lookupVarsByCols vars [c] of
      [] -> case lookupProcByCols procs [c] of
        [] -> "?"
        (x:_) -> "=" ++ x
      (x:_) -> x
    Temporary -> maybe "" snd (lookup c debugs)
    Argument -> case lookupProcByArgCol procs [c] of
      [] -> "?"
      (x:_) -> x
    Literal _ -> maybe (show c) snd (lookup c debugs)
    Magic -> ""

-- | "line:column" for a source location.
showSrcLoc loc = show (srcLine loc) ++ ":" ++ show (srcColumn loc)

-- | "(start - end)" for a (start, end) source span.
showSrcSpan (start, end) = "(" ++ showSrcLoc start ++ " - " ++ showSrcLoc end ++ ")"

-- | File name of a span's start location.
showSrcFile (start, _) = srcFilename start

-- | One-line description of column @c@ for the underdetermined-units
-- report: the recorded source span and expression text when available,
-- otherwise the same name lookups as 'showExpr'. Temporary and Literal
-- columns without debug info render as "?" rather than failing.
showExprLines :: [UnitVarCategory] -> VarColEnv -> ProcedureEnv -> DebugInfo -> Int -> String
showExprLines cats vars procs debugs c =
  case cats !! (c - 1) of
    Variable -> case lookup c debugs of
      Just located -> showLocated located
      Nothing -> case lookupVarsByCols vars [c] of
        [] -> case lookupProcByCols procs [c] of
          [] -> "?"
          (x:_) -> "=" ++ x
        (x:_) -> x
    Temporary -> maybe "?" showLocated (lookup c debugs)
    Argument -> case lookupProcByArgCol procs [c] of
      [] -> "?"
      (x:_) -> x
    Literal _ -> maybe "?" showLocated (lookup c debugs)
    Magic -> ""
  where
    showLocated (sp, expr) = showSrcSpan sp ++ "\t" ++ expr

-- | For an 'Argument' column, the name of the variable bound to it (if
-- any); "" otherwise.
showArgVars :: [UnitVarCategory] -> VarColEnv -> Int -> String
showArgVars cats vars c =
  case cats !! (c - 1) of
    Argument -> case lookupVarsByCols vars [c] of
      [] -> ""
      (x:_) -> x
    _ -> ""

-- | Short label for a column category.
showCat :: UnitVarCategory -> String
showCat Variable = "Var"
showCat Magic = "Magic"
showCat Temporary = "Temp"
showCat Argument = "Arg"
showCat (Literal False) = "Lit"
showCat (Literal True) = "Lit="

-- | For each column that is an argument of some procedure, "proc#index"
-- (0-based argument position) for the first such procedure found.
lookupProcByArgCol :: ProcedureEnv -> [Int] -> [String]
lookupProcByArgCol penv cols = mapMaybe (\j -> lookupEnv j penv) cols
  where
    lookupEnv _ [] = Nothing
    lookupEnv j ((p, (_, args)) : rest) =
      case elemIndex (VarCol j) args of
        Just k -> Just (p ++ "#" ++ show k)
        Nothing -> lookupEnv j rest

-- | For each column that is some procedure's result column, that
-- procedure's name.
lookupProcByCols :: ProcedureEnv -> [Int] -> [String]
lookupProcByCols penv cols = mapMaybe (\j -> lookupEnv j penv) cols
  where
    lookupEnv j [] = Nothing
    lookupEnv j ((p, (Just (VarCol i), _)) : rest)
      | i == j = Just p
      | otherwise = lookupEnv j rest
    lookupEnv j ((p, (Nothing, _)) : rest) = lookupEnv j rest

-- | For each column bound to a variable, that variable's name.
lookupVarsByCols :: VarColEnv -> [Int] -> [Variable]
lookupVarsByCols uenv cols = mapMaybe (\j -> lookupEnv j uenv) cols
  where
    lookupEnv j [] = Nothing
    lookupEnv j ((VarBinder (v, _), (VarCol i, _)) : rest)
      | i == j = Just v
      | otherwise = lookupEnv j rest

-- | For each column bound to a variable, that variable's full 'VarBinder'.
lookupVarBindersByCols :: VarColEnv -> [Int] -> [VarBinder]
lookupVarBindersByCols uenv cols = mapMaybe (\j -> lookupEnv j uenv) cols
  where
    lookupEnv j [] = Nothing
    lookupEnv j ((vb@(VarBinder (v, _)), (VarCol i, _)) : rest)
      | i == j = Just vb
      | otherwise = lookupEnv j rest

-- | "n" for integers, "n%d" otherwise.
showRational :: Rational -> String
showRational r =
  show (numerator r) ++ if denominator r == 1 then "" else "%" ++ show (denominator r)

-- *************************************
--   Insert unit declarations into code
--
-- *************************************

-- | Apply 'insertUnits' to every declaration in the block.
insertUnitsInBlock :: Block Annotation -> State UnitEnv (Block Annotation)
insertUnitsInBlock x = transformBiM insertUnits x

-- | Apply 'deleteUnits' to every declaration in the block.
removeUnitsInBlock :: Block Annotation -> Block Annotation
removeUnitsInBlock = transformBi deleteUnits

-- | Translate a unit annotation into a 'UnitConstant': products add up,
-- quotients subtract the denominator, and no unit is @Unitful []@.
convertUnit :: MeasureUnitSpec a -> State UnitEnv UnitConstant
convertUnit (UnitProduct _ units) = convertUnits units
convertUnit (UnitQuotient _ units1 units2) = liftM2 (-) (convertUnits units1) (convertUnits units2)
convertUnit (UnitNone _) = return $ Unitful []

-- | Sum (with the 'UnitConstant' @+@, starting from @Unitful []@) of each
-- unit raised to its exponent, converted left to right.
convertUnits :: [(MeasureUnit, Fraction a)] -> State UnitEnv UnitConstant
convertUnits units =
  foldl' (+) (Unitful []) <$> mapM (\(unit, f) -> convertSingleUnit unit (fromFraction f)) units

-- | Convert @unit^f@. A unit already in 'derivedUnitEnv' yields its stored
-- constant multiplied by @fromRational f@ (presumably scaling its
-- exponents). An unknown unit is first registered there as
-- @Unitful [(unit, 1)]@ and yields @Unitful [(unit, f)]@.
convertSingleUnit :: MeasureUnit -> Rational -> State UnitEnv UnitConstant
convertSingleUnit unit f = do
  denv <- gets derivedUnitEnv
  let uc f' = Unitful [(unit, f')]
  case lookup unit denv of
    Just uc' -> return $ uc' * fromRational f
    Nothing -> derivedUnitEnv << (unit, uc 1) >> return (uc f)

-- | Exponent value of a unit fraction; a missing fraction means 1.
-- Uses 'read' on the stored digit strings.
fromFraction :: Fraction a -> Rational
fromFraction (IntegerConst _ n) = fromInteger $ read n
fromFraction (FractionConst _ p q) = fromInteger (read p) / fromInteger (read q)
fromFraction (NullFraction _) = 1

-- | The unit conversion for a 'MeasureUnit' attribute, if this is one.
extractUnit :: Attr a -> [State UnitEnv UnitConstant]
extractUnit attr = case attr of
  MeasureUnit _ unit -> [convertUnit unit]
  _ -> []

-- | Unit to write back for column @m@ (the variable's column).
--
-- This finds the first row with a non-zero entry in column @m@ and defers
-- to 'lookupUnit''. When no row mentions @m@, the result is 'Nothing' for
-- 'Argument' columns and @Just (Unitful [])@ otherwise.
lookupUnit :: [UnitVarCategory] -> [Int] -> LinearSystem -> Col -> Maybe UnitConstant
lookupUnit ucats badCols system@(matrix, vector) m =
  let n = find (\n -> matrix ! (n, m) /= 0) [1 .. nrows matrix]
      defaultUnit = if ucats !! (m - 1) == Argument then Nothing else Just (Unitful [])
  in maybe defaultUnit (lookupUnit' ucats badCols system m) n

-- | Unit of column @m@ read from row @n@'s right-hand side.
--
-- The result is 'Nothing' if another 'Argument' column is non-zero in that
-- row. It is the right-hand side if @m@ is neither an Argument nor
-- underdetermined. Otherwise it is the right-hand side only when @m@ is
-- the row's sole non-zero column.
lookupUnit' :: [UnitVarCategory] -> [Int] -> LinearSystem -> Int -> Int -> Maybe UnitConstant
lookupUnit' ucats badCols (matrix, vector) m n
  | not $ null ms = Nothing
  | ucats !! (m - 1) /= Argument && m `notElem` badCols = Just $ vector !! (n - 1)
  | ms' /= [m] = Nothing
  | otherwise = Just $ vector !! (n - 1)
  where
    ms = filter significant [1 .. ncols matrix]
    significant m' = m' /= m && matrix ! (n, m') /= 0 && ucats !! (m' - 1) == Argument
    ms' = filter (\col -> matrix ! (n, col) /= 0) [1 .. ncols matrix]

-- | Add unit attributes to a declaration that is neither refactored yet
-- nor already annotated with units.
--
-- The declared variables are grouped by the unit 'lookupUnit' gives them
-- (consecutive equal units only, via 'groupBy'), and each group gets its
-- own declaration with the unit attribute added (?num = group size). The
-- original is returned unchanged when nothing would change or there are no
-- variables. Otherwise it becomes a 'DSeq' of a 'NullDecl' followed by the
-- new declarations.
insertUnits :: Decl Annotation -> State UnitEnv (Decl Annotation)
insertUnits decl@(Decl a sp@(s1, _) d t)
  | not (pRefactored a || hasUnits t) = do
      system <- gets linearSystem
      ucats <- gets unitVarCats
      badCols <- gets underdeterminedCols
      vColEnv <- gets varColEnv
      let varCol (Var _ s ((VarName _ v, _):_), _, _) =
            case lookupWithSrcSpan v s vColEnv of
              Just (VarCol m, _) -> m
              Nothing -> error $ "No variable " ++ show v
      let sameUnits = (==) `on` (lookupUnit ucats badCols system . varCol)
      let groups = groupBy sameUnits d
      types <- mapM (\g -> let ?num = length g in insertUnit ucats badCols system t (varCol (head g))) groups
      let a' = a { refactored = Just s1 }
          sp' = dropLine $ refactorSpan sp
          sp'' = (toCol0 s1, snd $ dropLine sp)
          decls = [Decl a' sp' group t' | (group, t') <- zip groups types]
      -- An empty variable list would otherwise reach foldr1 on [].
      if null decls || types == [t]
        then return decl
        else return $ DSeq a (NullDecl a' sp'') (foldr1 (DSeq a) decls)
insertUnits decl = return decl

-- | Remove unit information. A declaration carrying unit attributes loses
-- them (span adjusted with 'dropLine', annotation marked refactored at
-- 'toCol0'). A unit definition becomes a 'NullDecl'. Anything else is
-- unchanged.
deleteUnits :: Decl Annotation -> Decl Annotation
deleteUnits (Decl a sp@(s1, _) d t)
  | hasUnits t = Decl a' (dropLine sp) d t'
  where
    a' = a { refactored = Just $ toCol0 s1 }
    t' = deleteUnit t
deleteUnits (MeasureUnitDef a sp@(s1, _) d) = NullDecl a' sp'
  where
    a' = a { refactored = Just s1 }
    sp' = (toCol0 s1, snd $ dropLine sp)
deleteUnits decl = decl

-- | Whether the type's attribute list contains a unit attribute.
hasUnits :: Type a -> Bool
hasUnits (BaseType _ _ attrs _ _) = any isUnit attrs
hasUnits (ArrayT _ _ _ attrs _ _) = any isUnit attrs

-- | Whether an attribute is a 'MeasureUnit' annotation.
isUnit :: Attr a -> Bool
isUnit (MeasureUnit _ _) = True
isUnit _ = False

-- | Add the unit looked up for column @uv@ (if any) to the type's
-- attributes.
insertUnit :: (?num :: Int) => [UnitVarCategory] -> [Int] -> LinearSystem -> Type Annotation -> Int -> State UnitEnv (Type Annotation)
insertUnit ucats badCols system (BaseType aa tt attrs kind len) uv = do
  let unit = lookupUnit ucats badCols system uv
  u <- insertUnitAttribute unit attrs
  return $ BaseType aa tt u kind len
insertUnit ucats badCols system (ArrayT dims aa tt attrs kind len) uv = do
  let unit = lookupUnit ucats badCols system uv
  u <- insertUnitAttribute unit attrs
  return $ ArrayT dims aa tt u kind len

-- | Drop all unit attributes from a type.
deleteUnit :: Type Annotation -> Type Annotation
deleteUnit (BaseType aa tt attrs kind len) = BaseType aa tt (filter (not . isUnit) attrs) kind len
deleteUnit (ArrayT dims aa tt attrs kind len) = ArrayT dims aa tt (filter (not . isUnit) attrs) kind len

-- | Append a 'MeasureUnit' attribute for the unit, if there is one.
insertUnitAttribute :: (?num :: Int) => Maybe UnitConstant -> [Attr Annotation] -> State UnitEnv [Attr Annotation]
insertUnitAttribute (Just unit) attrs = do
  spec <- makeUnitSpec unit
  return $ attrs ++ [MeasureUnit unitAnnotation spec]
insertUnitAttribute Nothing attrs = return attrs

-- | Used for evaluation: add @k@ to the counter in 'evUnitsAdded' and
-- append @s@ to its list of recorded unit strings. A @k@ of 0 leaves the
-- counter unchanged but still records the string.
updateAdded :: Int -> String -> State UnitEnv ()
updateAdded k s = do
  (n, xs) <- gets evUnitsAdded
  evUnitsAdded =: (n + k, xs ++ [s])

-- | Build the source annotation for a unit constant, recording every
-- non-empty result with 'updateAdded' (weighted by ?num). Positive and
-- negative exponents become a product or a quotient; an empty unit
-- becomes 'UnitNone'.
makeUnitSpec :: (?num :: Int) => UnitConstant -> State UnitEnv (MeasureUnitSpec Annotation)
makeUnitSpec (UnitlessC r) = do
  -- NOTE(review): the original author flagged this "1^(p/q)" encoding
  -- with "hm!".
  let u = UnitProduct unitAnnotation
            [("1", FractionConst unitAnnotation (show $ numerator r) (show $ denominator r))]
  updateAdded ?num (pprint u)
  return u
makeUnitSpec (Unitful []) = return $ UnitNone unitAnnotation
makeUnitSpec (Unitful units)
  | null neg = record $ UnitProduct unitAnnotation (formatUnits pos)
  | otherwise = record $ UnitQuotient unitAnnotation (formatUnits pos) (formatUnits neg)
  where
    pos = filter (\(_, r) -> r > 0) units
    neg = [(unit, -r) | (unit, r) <- units, r < 0]
    record u = do
      updateAdded ?num (pprint u)
      return u

-- | Attach source fractions to unit exponents.
formatUnits :: [(MeasureUnit, Rational)] -> [(MeasureUnit, Fraction Annotation)]
formatUnits units = [(unit, toFraction r) | (unit, r) <- units]

-- | Exponent as a source fraction: omitted for 1, an integer when the
-- denominator is 1, otherwise p/q.
toFraction :: Rational -> Fraction Annotation
toFraction 1 = NullFraction unitAnnotation
toFraction r
  | q == 1 = IntegerConst unitAnnotation $ show p
  | otherwise = FractionConst unitAnnotation (show p) (show q)
  where
    p = numerator r
    q = denominator r