{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}

module NVD
  ( withVulnDB,
    getCVEs,
    Connection,
    ProductID,
    Version,
    CVE,
    CVEID,
    UTCTime,
  )
where

import CVE
  ( CPEMatch (..),
    CPEMatchRow (..),
    CVE (..),
    CVEID,
    cpeMatches,
    parseFeed,
  )
import Codec.Compression.GZip (decompress)
import Control.Exception (SomeException, try)
import Crypto.Hash.SHA256 (hashlazy)
import qualified Data.ByteString.Lazy.Char8 as BSL
import Data.Hex (hex, unhex)
import Data.List (group)
import qualified Data.Text as T
import Data.Time.Calendar (toGregorian)
import Data.Time.Clock
  ( UTCTime,
    diffUTCTime,
    getCurrentTime,
    nominalDay,
    utctDay,
  )
import Data.Time.ISO8601 (parseISO8601)
import Database.SQLite.Simple
  ( Connection,
    Only (..),
    Query (..),
    execute,
    executeMany,
    execute_,
    query,
    withConnection,
    withTransaction,
  )
import qualified NVDRules
import Network.HTTP.Conduit (simpleHttp)
import OurPrelude
import System.Directory
  ( XdgDirectory (..),
    createDirectoryIfMissing,
    getXdgDirectory,
    removeFile,
  )
import System.FilePath ((</>))
import Utils (ProductID, Version)
import Version (matchVersion)

-- | Either @recent@, @modified@, or any year since @2002@.
type FeedID = String

-- | Suffix that 'feedURL' appends to a feed's base URL to pick a file; this
-- module uses @.meta@ and @.json.gz@.
type Extension = String

-- | Last-modification time of a feed, as read from its @.meta@ file.
type Timestamp = UTCTime

-- | Raw SHA-256 digest bytes (decoded from the hex string in a @.meta@ file).
type Checksum = BSL.ByteString

-- | Schema version recorded in the @meta@ table of the database.
type DBVersion = Int

-- | The parts of a feed's @.meta@ file this module uses: the last-modified
-- timestamp and the checksum of the decompressed JSON feed.
data Meta
  = Meta Timestamp Checksum

-- | Database version the software expects. If the version recorded in the
-- database differs from this value, or the database has not been updated in
-- more than 7.5 days, the database will be deleted and rebuilt from scratch.
-- Bump this when the database layout changes or the build-time data filtering
-- changes.
softwareVersion :: DBVersion
softwareVersion = 2

-- | Location of the SQLite database file: @nvd.sqlite3@ inside the
-- @nixpkgs-update@ XDG cache directory. The directory is created if it does
-- not exist yet.
getDBPath :: IO FilePath
getDBPath = do
  cacheDir <- getXdgDirectory XdgCache "nixpkgs-update"
  createDirectoryIfMissing True cacheDir
  pure $ cacheDir </> "nvd.sqlite3"

-- | Run @action@ with a connection to the vulnerability database at
-- 'getDBPath'. Connection lifetime is managed by 'withConnection'.
withDB :: (Connection -> IO a) -> IO a
withDB action = do
  dbPath <- getDBPath
  withConnection dbPath action

-- | Store the current time as @last_update@ in the @meta@ table.
markUpdated :: Connection -> IO ()
markUpdated conn = do
  now <- getCurrentTime
  execute conn "UPDATE meta SET last_update = ?" [now]

-- | Rebuild the entire database, redownloading all data.
--
-- Deletes the database file, recreates the schema (tables @meta@, @cves@ and
-- @cpe_matches@ plus lookup indices on @cpe_matches@), downloads every feed
-- returned by 'allYears' and finally records the update time.
rebuildDB :: IO ()
rebuildDB = do
  dbPath <- getDBPath
  -- NOTE(review): 'removeFile' throws if the file does not exist. In this
  -- module rebuildDB only runs after 'needsRebuild' has opened the database,
  -- which presumably leaves a file behind; confirm before calling elsewhere.
  removeFile dbPath
  withConnection dbPath $ \conn -> do
    execute_ conn "CREATE TABLE meta (db_version int, last_update text)"
    -- The epoch timestamp keeps the database looking stale until
    -- 'markUpdated' runs at the very end, so an interrupted rebuild is
    -- detected by 'needsRebuild' and redone on the next run.
    execute
      conn
      "INSERT INTO meta VALUES (?, ?)"
      (softwareVersion, "1970-01-01 00:00:00" :: Text)
    execute_ conn $
      Query $
        T.unlines
          [ "CREATE TABLE cves (",
            "  cve_id text PRIMARY KEY,",
            "  description text,",
            "  published text,",
            "  modified text)"
          ]
    -- The REFERENCES clause must name the existing table @cves@. This module
    -- never enables SQLite foreign-key enforcement (off by default), so the
    -- clause documents the relationship rather than enforcing it.
    execute_ conn $
      Query $
        T.unlines
          [ "CREATE TABLE cpe_matches (",
            "  cve_id text REFERENCES cves,",
            "  part text,",
            "  vendor text,",
            "  product text,",
            "  version text,",
            "  \"update\" text,",
            "  edition text,",
            "  language text,",
            "  software_edition text,",
            "  target_software text,",
            "  target_hardware text,",
            "  other text,",
            "  matcher text)"
          ]
    -- Indices for the lookups done by getCVEs and the deletes in putCVEs.
    execute_ conn "CREATE INDEX matchers_by_cve ON cpe_matches(cve_id)"
    execute_ conn "CREATE INDEX matchers_by_product ON cpe_matches(product)"
    execute_ conn "CREATE INDEX matchers_by_vendor ON cpe_matches(vendor)"
    execute_
      conn
      "CREATE INDEX matchers_by_target_software ON cpe_matches(target_software)"
    years <- allYears
    forM_ years $ updateFeed conn
    markUpdated conn

-- | URL of an NVD JSON 1.1 feed file. For example, @feedURL "2019" ".meta"@ is
-- @https://nvd.nist.gov/feeds/json/cve/1.1/nvdcve-1.1-2019.meta@.
feedURL :: FeedID -> Extension -> String
feedURL feed ext =
  "https://nvd.nist.gov/feeds/json/cve/1.1/nvdcve-1.1-" <> feed <> ext

-- | Throw an 'IOError' created by 'userError' with the given message.
throwString :: String -> IO a
throwString = ioError . userError

-- | 'throwString' for a 'Text' message.
throwText :: Text -> IO a
throwText = throwString . T.unpack

-- | Feed IDs of all yearly feeds, from 2002 up to and including the current
-- year (taken from the UTC date of 'getCurrentTime').
allYears :: IO [FeedID]
allYears = do
  now <- getCurrentTime
  let (year, _, _) = toGregorian $ utctDay now
  return $ map show [2002 .. year]

-- | Parse the contents of a feed's @.meta@ file: one @key:value@ pair per line
-- (a trailing @\\r@ is ignored). Only @lastModifiedDate@ (ISO 8601) and
-- @sha256@ (hex) are read. A missing or malformed field yields 'Left' with a
-- description of the problem.
parseMeta :: BSL.ByteString -> Either T.Text Meta
parseMeta raw = do
  -- Split each line at its first ':'. 'BSL.drop 1' (unlike the partial
  -- 'BSL.tail') yields an empty value for a line without a colon, so such a
  -- line produces a 'Left' below instead of an exception.
  let splitLine = second (BSL.drop 1) . BSL.break (== ':') . BSL.takeWhile (/= '\r')
  let fields = map splitLine $ BSL.lines raw
  lastModifiedDate <-
    note "no lastModifiedDate in meta" $
      lookup "lastModifiedDate" fields
  sha256 <- note "no sha256 in meta" $ lookup "sha256" fields
  timestamp <-
    note "invalid lastModifiedDate in meta" $
      parseISO8601 $
        BSL.unpack lastModifiedDate
  checksum <- note "invalid sha256 in meta" $ unhex sha256
  return $ Meta timestamp checksum

-- | Download a feed's @.meta@ file and parse it with 'parseMeta'. A parse
-- failure is thrown as an 'IOError'.
getMeta :: FeedID -> IO Meta
getMeta feed = do
  raw <- simpleHttp $ feedURL feed ".meta"
  either throwText pure $ parseMeta raw

-- | Fetch the CVE with the given ID from the @cves@ table. Calls 'fail' if no
-- row, or more than one row, has that ID.
getCVE :: Connection -> CVEID -> IO CVE
getCVE conn cveID_ = do
  cves <-
    query
      conn
      ( Query $
          T.unlines
            [ "SELECT cve_id, description, published, modified",
              "FROM cves",
              "WHERE cve_id = ?"
            ]
      )
      (Only cveID_)
  case cves of
    [cve] -> pure cve
    [] -> fail $ "no cve with id " <> (T.unpack cveID_)
    _ -> fail $ "multiple cves with id " <> (T.unpack cveID_)

-- | CVEs affecting @productID@ at @version@.
--
-- A CPE match row is a candidate when @productID@ equals its vendor,
-- product, edition, software edition or target software. A candidate counts
-- when its version matcher accepts @version@ and 'NVDRules.filter' keeps it.
-- Each affected CVE is returned once, in @cve_id@ order.
getCVEs :: Connection -> ProductID -> Version -> IO [CVE]
getCVEs conn productID version = do
  matches :: [CPEMatchRow] <-
    query
      conn
      ( Query $
          T.unlines
            [ "SELECT",
              "  cve_id,",
              "  part,",
              "  vendor,",
              "  product,",
              "  version,",
              "  \"update\",",
              "  edition,",
              "  language,",
              "  software_edition,",
              "  target_software,",
              "  target_hardware,",
              "  other,",
              "  matcher",
              "FROM cpe_matches",
              "WHERE vendor = ? or product = ? or edition = ? or software_edition = ? or target_software = ?",
              "ORDER BY cve_id"
            ]
      )
      (productID, productID, productID, productID, productID)
  let relevantIDs =
        [ cveID cve
          | CPEMatchRow cve cpeMatch <- matches,
            matchVersion (cpeMatchVersionMatcher cpeMatch) version,
            NVDRules.filter cve cpeMatch productID version
        ]
      -- Rows are ordered by cve_id, so all rows of one CVE form a single run
      -- after 'group'; keeping each run's first element deduplicates. The
      -- pattern match keeps this total (no partial 'head').
      cveIDs = [firstID | (firstID : _) <- group relevantIDs]
  forM cveIDs $ getCVE conn

-- | Store CVEs and their CPE match rows in one transaction. Existing rows in
-- @cves@ and @cpe_matches@ with the same CVE IDs are deleted first, so
-- re-importing a feed replaces data instead of duplicating it.
putCVEs :: Connection -> [CVE] -> IO ()
putCVEs conn cves = do
  withTransaction conn $ do
    let ids = map (Only . cveID) cves
    executeMany conn "DELETE FROM cves WHERE cve_id = ?" ids
    executeMany
      conn
      ( Query $
          T.unlines
            [ "INSERT INTO cves(cve_id, description, published, modified)",
              "VALUES (?, ?, ?, ?)"
            ]
      )
      cves
    executeMany conn "DELETE FROM cpe_matches WHERE cve_id = ?" ids
    executeMany
      conn
      ( Query $
          T.unlines
            [ "INSERT INTO cpe_matches(",
              "  cve_id,",
              "  part,",
              "  vendor,",
              "  product,",
              "  version,",
              "  \"update\",",
              "  edition,",
              "  language,",
              "  software_edition,",
              "  target_software,",
              "  target_hardware,",
              "  other,",
              "  matcher)",
              "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
            ]
      )
      (cpeMatches cves)

-- | Read @db_version@ and @last_update@ from the @meta@ table. Calls 'fail'
-- unless the table holds exactly one row.
getDBMeta :: Connection -> IO (DBVersion, UTCTime)
getDBMeta conn = do
  rows <- query conn "SELECT db_version, last_update FROM meta" ()
  case rows of
    [meta] -> pure meta
    _ -> fail "failed to get meta information"

-- | Whether the database must be rebuilt from scratch. This is the case when:
--
-- * its metadata cannot be read (any exception is printed and means rebuild,
--   e.g. when the @meta@ table does not exist yet);
-- * it was last updated more than 7.5 days ago; or
-- * its recorded version differs from 'softwareVersion'.
needsRebuild :: IO Bool
needsRebuild = do
  dbMeta <- try $ withDB getDBMeta
  currentTime <- getCurrentTime
  case dbMeta of
    Left (e :: SomeException) -> do
      putStrLn $ "rebuilding database because " <> show e
      pure True
    Right (dbVersion, t) ->
      pure $
        diffUTCTime currentTime t > (7.5 * nominalDay)
          || dbVersion /= softwareVersion

-- | Download a feed and store it in the database. A feed that fails to parse
-- is thrown as an 'IOError'.
updateFeed :: Connection -> FeedID -> IO ()
updateFeed conn feedID = do
  putStrLn $ "Updating National Vulnerability Database feed (" <> feedID <> ")"
  json <- downloadFeed feedID
  parsedCVEs <- either throwText pure $ parseFeed json
  putCVEs conn parsedCVEs

-- | Update the vulnerability database and run an action with a connection to
-- it.
--
-- The database is first rebuilt from scratch if 'needsRebuild' says so. Then,
-- if the last update is more than 0.25 days (six hours) old, the @modified@
-- feed is applied before @action@ runs.
withVulnDB :: (Connection -> IO a) -> IO a
withVulnDB action = do
  rebuild <- needsRebuild
  when rebuild rebuildDB
  withDB $ \conn -> do
    -- Read the metadata through the connection we already hold rather than
    -- opening a second connection to the same file.
    (_, lastUpdate) <- getDBMeta conn
    currentTime <- getCurrentTime
    when (diffUTCTime currentTime lastUpdate > (0.25 * nominalDay)) $ do
      updateFeed conn "modified"
      markUpdated conn
    action conn

-- | Download a feed's gzip-compressed JSON and return the decompressed bytes.
--
-- The SHA-256 of the decompressed data is checked against the checksum in the
-- feed's @.meta@ file. On a mismatch an 'IOError' is thrown that shows both
-- digests in hex.
downloadFeed :: FeedID -> IO BSL.ByteString
downloadFeed feed = do
  Meta _ expectedChecksum <- getMeta feed
  compressed <- simpleHttp $ feedURL feed ".json.gz"
  let raw = decompress compressed
  let actualChecksum = BSL.fromStrict $ hashlazy raw
  when (actualChecksum /= expectedChecksum) $
    throwString $
      "wrong hash, expected: "
        <> BSL.unpack (hex expectedChecksum)
        <> " got: "
        <> BSL.unpack (hex actualChecksum)
  return raw