{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict            #-}
module Tokstyle.C.Linter.Memset (analyse) where

import           Control.Monad                   (unless)
import           Data.Functor.Identity           (Identity)
import           Language.C.Analysis.AstAnalysis (ExprSide (..), tExpr)
import           Language.C.Analysis.DefTable    (lookupTag)
import           Language.C.Analysis.SemError    (typeMismatch)
import           Language.C.Analysis.SemRep      (CompType (..),
                                                  CompTypeRef (..), GlobalDecls,
                                                  MemberDecl (..), TagDef (..),
                                                  Type (..), TypeName (..),
                                                  VarDecl (..))
import           Language.C.Analysis.TravMonad   (MonadTrav, Trav, TravT,
                                                  getDefTable, recordError,
                                                  throwTravError)
import           Language.C.Analysis.TypeUtils   (canonicalType)
import           Language.C.Data.Error           (userErr)
import           Language.C.Data.Ident           (Ident (..))
import           Language.C.Pretty               (pretty)
import           Language.C.Syntax.AST           (CExpression (..), annotation)
import           Tokstyle.C.Env                  (Env)
import           Tokstyle.C.TraverseAst          (AstActions (..), astActions,
                                                  traverseAst)


-- | Determine whether a value of the given type contains a pointer anywhere
-- in its in-memory representation.
--
-- * Pointer types contain a pointer.
-- * Array types contain a pointer if their element type does (elements are
--   stored inline).
-- * Struct/union types contain a pointer if /any/ of their members does; the
--   definition is looked up in the current definition table.
-- * All other types are considered pointer-free.
--
-- The type is canonicalised first (via 'canonicalType') before matching.
-- Raises a user error through 'throwTravError' if a struct/union reference
-- cannot be resolved to a 'CompDef' in the definition table.
hasPtrs :: MonadTrav m => Type -> m Bool
hasPtrs ty = case canonicalType ty of
    DirectType (TyComp (CompTypeRef name _ _)) _ _ -> do
        defs <- getDefTable
        case lookupTag name defs of
            Just (Right (CompDef (CompType _ _ members _ _))) ->
                -- One pointer member is enough to make the whole aggregate
                -- unsafe to zero-fill, so any member counts; an aggregate with
                -- no members therefore has no pointers.
                or <$> mapM memberHasPtrs members
            _ ->
                throwTravError $ userErr $
                    "couldn't find struct/union type `" <> show (pretty name) <> "`"
    PtrType{} -> return True
    -- Arrays embed their elements directly, e.g. @int *ptrs[4]@, arrays of
    -- structs with pointer members, or the inner dimension of a 2D array.
    ArrayType elemTy _ _ _ -> hasPtrs elemTy
    _ -> return False

-- | Whether a single struct/union member's declared type contains pointers.
--
-- A 'MemberDecl' carrying a 'VarDecl' is checked with 'hasPtrs' on its
-- declared type; any other member form (presumably anonymous bit-fields —
-- confirm against language-c's 'MemberDecl') is treated as pointer-free.
memberHasPtrs :: MonadTrav m => MemberDecl -> m Bool
memberHasPtrs decl = case decl of
    MemberDecl (VarDecl _ _ memberTy) _ _ -> hasPtrs memberTy
    _                                     -> pure False


-- | Decide whether the destination argument of a @memset@ call, whose type is
-- @ty@, may be zero-filled.
--
-- After canonicalisation the type must be a pointer or an array. It is
-- allowed only when the pointed-to (or element) type contains no pointers,
-- as judged by 'hasPtrs'. Any other type is rejected with a user error
-- raised through 'throwTravError'.
memsetAllowed :: MonadTrav m => Type -> m Bool
memsetAllowed ty =
    case canonicalType ty of
        PtrType pointee _ _    -> pointerFree pointee
        ArrayType elemTy _ _ _ -> pointerFree elemTy
        _                      -> throwTravError . userErr $
            "value of type `" <> show (pretty ty) <> "` cannot be passed to memset"
  where
    -- The target memory is safe to clear iff it holds no pointers.
    pointerFree target = not <$> hasPtrs target


-- | AST actions that inspect every call of the form @memset(dst, c, n)@.
--
-- The destination's type is computed with 'tExpr'. When 'memsetAllowed'
-- rejects it, a 'typeMismatch' error is recorded at the destination's source
-- location. In every case the traversal continues with the default action
-- for the expression.
linter :: AstActions (TravT Env Identity)
linter = astActions { doExpr = checkMemset }
  where
    -- Inspect one expression, then always run the continuation.
    checkMemset expr continue = do
        case expr of
            CCall (CVar (Ident "memset" _ _) _) [dst, _, _] _ -> checkDest dst
            _                                                 -> return ()
        continue

    -- Record an error if the memset destination contains pointers.
    checkDest dst = do
        dstTy   <- tExpr [] RValue dst
        allowed <- memsetAllowed dstTy
        unless allowed $
            let site = (annotation dst, dstTy)
                msg  = "disallowed memset argument `" <> show (pretty dst)
                    <> "` of type `" <> show (pretty dstTy)
                    <> "`, which contains pointers"
            in recordError $ typeMismatch msg site site


-- | Entry point: traverse all global declarations with the memset 'linter',
-- recording an error for each @memset@ call whose destination the linter
-- rejects.
analyse :: GlobalDecls -> Trav Env ()
analyse decls = traverseAst linter decls