{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module DataFrame.IO.Parquet where

import Control.Exception (throw)
import Control.Monad
import Data.Bits
import qualified Data.ByteString as BSO
import Data.Either
import Data.IORef
import Data.Int
import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Text as T
import Data.Text.Encoding
import Data.Time
import Data.Time.Clock.POSIX (posixSecondsToUTCTime)
import Data.Word
import DataFrame.Errors (DataFrameException (ColumnNotFoundException))
import qualified DataFrame.Internal.Column as DI
import DataFrame.Internal.DataFrame (DataFrame)
import DataFrame.Internal.Expression (Expr, getColumns)
import qualified DataFrame.Operations.Core as DI
import DataFrame.Operations.Merge ()
import qualified DataFrame.Operations.Subset as DS
import System.FilePath.Glob (glob)

import DataFrame.IO.Parquet.Dictionary
import DataFrame.IO.Parquet.Levels
import DataFrame.IO.Parquet.Page
import DataFrame.IO.Parquet.Thrift
import DataFrame.IO.Parquet.Types
import System.Directory (doesDirectoryExist)

import qualified Data.Vector.Unboxed as VU
import System.FilePath ((</>))

-- | Options controlling which columns and rows of a parquet file end up in the dataframe.
data ParquetReadOptions = ParquetReadOptions
    { selectedColumns :: Maybe [T.Text]
    -- ^ Columns to keep in the result; 'Nothing' keeps every column.
    , predicate :: Maybe (Expr Bool)
    -- ^ Row filter applied through @DS.filterWhere@; 'Nothing' keeps every row.
    , rowRange :: Maybe (Int, Int)
    -- ^ Bounds handed to @DS.range@ after filtering and column selection;
    -- 'Nothing' keeps all rows. NOTE(review): the exact meaning of the pair
    -- (inclusive/exclusive end) is defined by @DS.range@ - confirm there.
    }
    deriving (Eq, Show)

-- | Read options that select every column, apply no predicate and keep all rows.
defaultParquetReadOptions :: ParquetReadOptions
defaultParquetReadOptions = ParquetReadOptions noColumnSelection noPredicate noRowRange
  where
    noColumnSelection = Nothing
    noPredicate = Nothing
    noRowRange = Nothing

{- | Load a single parquet file into a dataframe using 'defaultParquetReadOptions'
(all columns, no predicate, no row range).

==== __Example__
@
ghci> D.readParquet ".\/data\/mtcars.parquet"
@
-}
readParquet :: FilePath -> IO DataFrame
readParquet path = readParquetWithOpts defaultParquetReadOptions path

{- | Read a single parquet file into a dataframe, honouring the given options.

The file is read into memory once and the same bytes are used for the footer
metadata and for the column chunks. Only column chunks whose leaf name is in
'selectedColumns' (plus any column referenced by 'predicate') are decoded; with
no selection every chunk is decoded. Chunks of the same column from different
row groups are concatenated. Finally the predicate, the column selection and
the row range are applied, in that order, by 'applyReadOptions'.

Throws 'ColumnNotFoundException' when a requested column is not among the leaf
names of the file schema, and calls 'error' when the footer magic is not @PAR1@.
-}
readParquetWithOpts :: ParquetReadOptions -> FilePath -> IO DataFrame
readParquetWithOpts opts path = do
    -- Read the file once; the footer metadata and the column chunks are both
    -- sliced out of these bytes.
    contents <- BSO.readFile path
    let (metadataSize, magicString) = readMetadataSizeFromFooter contents
    when (magicString /= "PAR1") $ error "Invalid Parquet file"
    fileMetadata <- readMetadata contents metadataSize

    let schemaElements = schema fileMetadata
    -- Schema without its first element (the root in the Parquet format).
    let schemaTail = drop 1 schemaElements
    -- Leaf elements only. Column chunks of a row group follow schema leaf
    -- order, so a chunk index addresses this list, not the flattened schema
    -- (which also contains group nodes).
    let leafElements = filter ((== 0) . numChildren) schemaTail

    let columnNames = map fst (getColumnPaths schemaTail)
    -- T.splitOn never returns an empty list, so 'last' is safe here.
    let leafNames = map (last . T.splitOn ".") columnNames
    let availableSelectedColumns = L.nub leafNames
    let predicateColumns = maybe [] (L.nub . getColumns) (predicate opts)
    -- Columns used by the predicate must be decoded even when not selected;
    -- 'applySelectedColumns' drops them again after filtering.
    let selectedColumnsForRead =
            fmap (\selected -> L.nub (selected ++ predicateColumns)) (selectedColumns opts)
    let selectedColumnSet = S.fromList <$> selectedColumnsForRead
    let shouldReadColumn colName = maybe True (colName `S.member`) selectedColumnSet

    forM_ selectedColumnsForRead $ \requested -> do
        let missing = requested L.\\ availableSelectedColumns
        unless (L.null missing) $
            throw
                ( ColumnNotFoundException
                    (T.pack $ show missing)
                    "readParquetWithOpts"
                    availableSelectedColumns
                )

    colMap <- newIORef (M.empty :: M.Map T.Text DI.Column)
    lTypeMap <- newIORef (M.empty :: M.Map T.Text LogicalType)

    -- NOTE(review): 'pathToElement' always returns [], so the guard below can
    -- only succeed for an empty target path; for any real column path this
    -- yields Nothing. Left as-is because the intended path reconstruction is
    -- not visible here.
    let getTypeLength :: [String] -> Maybe Int32
        getTypeLength wantedPath = findTypeLength schemaElements wantedPath 0
          where
            findTypeLength [] _ _ = Nothing
            findTypeLength (s : ss) targetPath depth
                | map T.unpack (pathToElement s ss depth) == targetPath
                    && elementType s == STRING
                    && typeLength s > 0 =
                    Just (typeLength s)
                | otherwise =
                    findTypeLength ss targetPath (if numChildren s > 0 then depth + 1 else depth)

            pathToElement _ _ _ = []

    forM_ (rowGroups fileMetadata) $ \rowGroup ->
        forM_ (zip (rowGroupColumns rowGroup) [0 :: Int ..]) $ \(colChunk, colIdx) -> do
            let metadata = columnMetaData colChunk
            let colPath = columnPathInSchema metadata
            -- Columns are keyed by the last path segment; chunks without a
            -- path get a positional name.
            let colName =
                    if null colPath
                        then T.pack $ "col_" ++ show colIdx
                        else T.pack $ last colPath

            when (shouldReadColumn colName) $ do
                let colDataPageOffset = columnDataPageOffset metadata
                let colDictionaryPageOffset = columnDictionaryPageOffset metadata
                -- Start at the dictionary page when there is one in front of
                -- the data pages, otherwise at the first data page.
                let colStart =
                        if colDictionaryPageOffset > 0 && colDataPageOffset > colDictionaryPageOffset
                            then colDictionaryPageOffset
                            else colDataPageOffset
                let colLength = columnTotalCompressedSize metadata

                let columnBytes =
                        BSO.take (fromIntegral colLength) (BSO.drop (fromIntegral colStart) contents)

                pages <- readAllPages (columnCodec metadata) columnBytes

                let maybeTypeLength =
                        if columnType metadata == PFIXED_LEN_BYTE_ARRAY
                            then getTypeLength colPath
                            else Nothing

                let primaryEncoding = maybe EPLAIN fst (L.uncons (columnEncodings metadata))

                let (maxDef, maxRep) = levelsForPath schemaTail colPath
                -- Look the logical type up among the leaves; indexing the
                -- flattened schema would hit a group or a neighbouring leaf
                -- as soon as the schema contains a group.
                let lType = case drop colIdx leafElements of
                        (leaf : _) -> logicalType leaf
                        [] ->
                            error $
                                "readParquetWithOpts: no schema leaf for column chunk "
                                    ++ show colIdx
                column <-
                    processColumnPages
                        (maxDef, maxRep)
                        pages
                        (columnType metadata)
                        primaryEncoding
                        maybeTypeLength
                        lType

                -- Strict updates so the maps do not pile up thunks across row groups.
                modifyIORef' colMap (M.insertWith DI.concatColumnsEither colName column)
                modifyIORef' lTypeMap (M.insert colName lType)

    finalColMap <- readIORef colMap
    finalLTypeMap <- readIORef lTypeMap
    -- NOTE(review): the maps are keyed by leaf name while 'columnNames' holds
    -- dotted paths, so a leaf inside a named group never passes this filter.
    let orderedColumns =
            map
                ( \name ->
                    ( name
                    , applyLogicalType (finalLTypeMap M.! name) $ finalColMap M.! name
                    )
                )
                (filter (`M.member` finalColMap) columnNames)

    pure $ applyReadOptions opts (DI.fromNamedColumns orderedColumns)

{- | Read every file matched by a directory or glob pattern and concatenate the
results, using 'defaultParquetReadOptions'.
-}
readParquetFiles :: FilePath -> IO DataFrame
readParquetFiles path = readParquetFilesWithOpts defaultParquetReadOptions path

{- | Read several parquet files and concatenate them into one dataframe.

If @path@ is an existing directory every entry directly inside it is read;
otherwise @path@ is used as a glob pattern. Directories among the matches are
skipped. Column selection and predicate are applied per file; the row range is
applied once, to the concatenated result.

Calls 'error' when no file matches.
-}
readParquetFilesWithOpts :: ParquetReadOptions -> FilePath -> IO DataFrame
readParquetFilesWithOpts opts path = do
    isDir <- doesDirectoryExist path
    let pat = if isDir then path </> "*" else path
    matches <- glob pat
    -- Sort the paths so the concatenation order - and therefore the rows a
    -- row range picks - does not depend on directory listing order.
    files <- L.sort <$> filterM (fmap not . doesDirectoryExist) matches

    case files of
        [] ->
            error $
                "readParquetFiles: no parquet files found for " ++ path
        _ -> do
            -- The row range refers to the combined frame, so it must not be
            -- applied to the individual files.
            let optsWithoutRowRange = opts{rowRange = Nothing}
            dfs <- mapM (readParquetWithOpts optsWithoutRowRange) files
            pure (applyRowRange opts (mconcat dfs))

-- | Restrict the dataframe with @DS.range@ when a row range is set; otherwise return it unchanged.
applyRowRange :: ParquetReadOptions -> DataFrame -> DataFrame
applyRowRange opts df =
    case rowRange opts of
        Nothing -> df
        Just bounds -> DS.range bounds df

-- | Keep only the selected columns via @DS.select@ when a selection is set; otherwise return the dataframe unchanged.
applySelectedColumns :: ParquetReadOptions -> DataFrame -> DataFrame
applySelectedColumns opts df =
    case selectedColumns opts of
        Nothing -> df
        Just wanted -> DS.select wanted df

-- | Filter rows with @DS.filterWhere@ when a predicate is set; otherwise return the dataframe unchanged.
applyPredicate :: ParquetReadOptions -> DataFrame -> DataFrame
applyPredicate opts df =
    case predicate opts of
        Nothing -> df
        Just condition -> DS.filterWhere condition df

{- | Apply all read options to a dataframe: first the predicate, then the column
selection, then the row range.
-}
applyReadOptions :: ParquetReadOptions -> DataFrame -> DataFrame
applyReadOptions opts df = applyRowRange opts projected
  where
    filtered = applyPredicate opts df
    projected = applySelectedColumns opts filtered

{- | Read a parquet file from disk and parse its footer metadata.

Calls 'error' with @Invalid Parquet file@ when the trailing magic is not @PAR1@.
-}
readMetadataFromPath :: FilePath -> IO FileMetadata
readMetadataFromPath path = do
    contents <- BSO.readFile path
    let footer = contents `seq` readMetadataSizeFromFooter contents
    case footer of
        (size, magicString)
            | magicString /= "PAR1" -> error "Invalid Parquet file"
            | otherwise -> readMetadata contents size

{- | Decode the last eight bytes of a parquet file.

Returns the metadata size (the four bytes before the magic, combined
least-significant byte first through a 32-bit signed integer) together with the
trailing four magic bytes.

Input shorter than eight bytes has no footer; in that case @(0, empty)@ is
returned so that callers comparing the magic against @PAR1@ reject the file
instead of hitting an out-of-bounds index.
-}
readMetadataSizeFromFooter :: BSO.ByteString -> (Int, BSO.ByteString)
readMetadataSizeFromFooter contents
    | footerOffset < 0 = (0, BSO.empty)
    | otherwise = (size, magicString)
  where
    footerOffset = BSO.length contents - 8
    -- The footer is exactly eight bytes: four size bytes, then four magic bytes.
    (sizeBytes, magicString) = BSO.splitAt 4 (BSO.drop footerOffset contents)
    size =
        fromIntegral $
            L.foldl' (.|.) (0 :: Int32) $
                zipWith shift (map fromIntegral (BSO.unpack sizeBytes)) [0, 8, 16, 24]

{- | List the leaf columns of a flattened schema as @(dotted path, leaf index)@.

The input is the depth-first flattening of the schema tree, where
'numChildren' gives the number of /direct/ children of an element. Elements
with zero children are leaves; each leaf gets the next index, starting at 0.
A group contributes its name to the path of its descendants unless the name is
empty.

A group's subtree is consumed recursively, so leaves that follow a nested group
are still attributed to the enclosing group (taking only @numChildren@ flat
elements would cut the subtree short as soon as a child is itself a group).
-}
getColumnPaths :: [SchemaElement] -> [(T.Text, Int)]
getColumnPaths = collectAll 0
  where
    -- Walk top-level subtrees until the element list is exhausted.
    collectAll :: Int -> [SchemaElement] -> [(T.Text, Int)]
    collectAll _ [] = []
    collectAll idx elems =
        let (leaves, nextIdx, rest) = subtree [] idx elems
         in leaves ++ collectAll nextIdx rest

    -- Consume exactly one subtree; return its leaves, the next free leaf
    -- index and the elements that follow the subtree.
    subtree ::
        [T.Text] ->
        Int ->
        [SchemaElement] ->
        ([(T.Text, Int)], Int, [SchemaElement])
    subtree _ idx [] = ([], idx, [])
    subtree path idx (s : ss)
        | numChildren s == 0 =
            ([(T.intercalate "." (path ++ [elementName s]), idx)], idx + 1, ss)
        | otherwise =
            let newPath = if T.null (elementName s) then path else path ++ [elementName s]
             in children (fromIntegral (numChildren s)) newPath idx ss

    -- Consume up to n sibling subtrees that share the same parent path.
    children ::
        Int ->
        [T.Text] ->
        Int ->
        [SchemaElement] ->
        ([(T.Text, Int)], Int, [SchemaElement])
    children _ _ idx [] = ([], idx, [])
    children n path idx elems
        | n <= 0 = ([], idx, elems)
        | otherwise =
            let (headLeaves, idx1, rest1) = subtree path idx elems
                (tailLeaves, idx2, rest2) = children (n - 1) path idx1 rest1
             in (headLeaves ++ tailLeaves, idx2, rest2)

{- | Decode all pages of one column chunk into a single column.

Arguments, in order: the maximum definition and repetition levels of the
column, the pages of the chunk, the physical type, the chunk's primary
encoding (ignored; each data page carries its own encoding), the type length
for fixed-length byte arrays, and the logical type (ignored here).

Only the first dictionary page is used. Data pages (v1 and v2) are supported
with PLAIN, RLE_DICTIONARY and PLAIN_DICTIONARY encodings; any other encoding
calls 'error'. The decoded pages are concatenated in order; a chunk without
data pages yields an empty @Maybe Int@ column.
-}
processColumnPages ::
    (Int, Int) ->
    [Page] ->
    ParquetType ->
    ParquetEncoding ->
    Maybe Int32 ->
    LogicalType ->
    IO DI.Column
processColumnPages (maxDef, maxRep) pages pType _ maybeTypeLength _ = do
    cols <- forM (filter isDataPage pages) $ \page ->
        case pageTypeHeader (pageHeader page) of
            DataPageHeader{..} -> do
                let n = fromIntegral dataPageHeaderNumValues
                let (defLvls, _repLvls, afterLvls) =
                        readLevelsV1 n maxDef maxRep (pageBytes page)
                -- A value is physically present only at the maximum definition level.
                let nPresent = length (filter (== maxDef) defLvls)
                decodeValues "v1" dataPageHeaderEncoding defLvls nPresent afterLvls
            DataPageHeaderV2{..} -> do
                let n = fromIntegral dataPageHeaderV2NumValues
                let (defLvls, _repLvls, afterLvls) =
                        readLevelsV2
                            n
                            maxDef
                            maxRep
                            definitionLevelByteLength
                            repetitionLevelByteLength
                            (pageBytes page)
                -- Prefer the null count from the header; fall back to counting
                -- definition levels when the header reports no nulls.
                let nPresent =
                        if dataPageHeaderV2NumNulls > 0
                            then fromIntegral (dataPageHeaderV2NumValues - dataPageHeaderV2NumNulls)
                            else length (filter (== maxDef) defLvls)
                decodeValues "v2" dataPageHeaderV2Encoding defLvls nPresent afterLvls
            -- Cannot happen as these are filtered out by isDataPage above
            DictionaryPageHeader{} -> error "processColumnPages: impossible DictionaryPageHeader"
            INDEX_PAGE_HEADER -> error "processColumnPages: impossible INDEX_PAGE_HEADER"
            PAGE_TYPE_HEADER_UNKNOWN -> error "processColumnPages: impossible PAGE_TYPE_HEADER_UNKNOWN"
    -- This is N^2. We should probably use mutable columns here.
    case cols of
        [] -> pure $ DI.fromList ([] :: [Maybe Int])
        (c : cs) ->
            pure $
                L.foldl' (\l r -> fromRight (error "concat failed") (DI.concatColumns l r)) c cs
  where
    -- Dictionary values of the first dictionary page, if the chunk has one.
    dictValsM =
        case filter isDictionaryPage pages of
            [] -> Nothing
            (dictPage : _) ->
                case pageTypeHeader (pageHeader dictPage) of
                    DictionaryPageHeader{..} ->
                        -- For booleans the extra argument is the number of
                        -- dictionary values; for every other type it is the
                        -- fixed type length.
                        let countForBools =
                                if pType == PBOOLEAN
                                    then Just (fromIntegral dictionaryPageHeaderNumValues)
                                    else maybeTypeLength
                         in Just (readDictVals pType (pageBytes dictPage) countForBools)
                    _ -> Nothing

    -- Decode the value section of one data page according to its encoding.
    -- `version` ("v1"/"v2") only appears in the unsupported-encoding message.
    decodeValues version encoding defLvls nPresent afterLvls =
        case encoding of
            EPLAIN -> decodePlain defLvls nPresent afterLvls
            ERLE_DICTIONARY -> decodeDictV1 dictValsM maxDef defLvls nPresent afterLvls
            EPLAIN_DICTIONARY -> decodeDictV1 dictValsM maxDef defLvls nPresent afterLvls
            other -> error ("Unsupported " ++ version ++ " encoding: " ++ show other)

    -- Decode `nPresent` PLAIN-encoded values of the column's physical type and
    -- re-insert the nulls indicated by the definition levels.
    decodePlain defLvls nPresent afterLvls =
        case pType of
            PBOOLEAN ->
                let (vals, _) = readNBool nPresent afterLvls
                 in pure (toMaybeBool maxDef defLvls vals)
            PINT32 ->
                let (vals, _) = readNInt32 nPresent afterLvls
                 in pure (toMaybeInt32 maxDef defLvls vals)
            PINT64 ->
                let (vals, _) = readNInt64 nPresent afterLvls
                 in pure (toMaybeInt64 maxDef defLvls vals)
            PINT96 ->
                let (vals, _) = readNInt96Times nPresent afterLvls
                 in pure (toMaybeUTCTime maxDef defLvls vals)
            PFLOAT ->
                let (vals, _) = readNFloat nPresent afterLvls
                 in pure (toMaybeFloat maxDef defLvls vals)
            PDOUBLE ->
                let (vals, _) = readNDouble nPresent afterLvls
                 in pure (toMaybeDouble maxDef defLvls vals)
            PBYTE_ARRAY ->
                let (raws, _) = readNByteArrays nPresent afterLvls
                    texts = map decodeUtf8 raws
                 in pure (toMaybeText maxDef defLvls texts)
            PFIXED_LEN_BYTE_ARRAY ->
                case maybeTypeLength of
                    Just len ->
                        let (raws, _) = splitFixed nPresent (fromIntegral len) afterLvls
                            texts = map decodeUtf8 raws
                         in pure (toMaybeText maxDef defLvls texts)
                    Nothing -> error "FIXED_LEN_BYTE_ARRAY requires type length"
            PARQUET_TYPE_UNKNOWN -> error "Cannot read unknown Parquet type"

{- | Convert a raw column according to its parquet logical type.

* Timestamps: the raw 'Int64' tick count is divided by the number of ticks per
  second ('unitDivisor') and turned into a 'UTCTime'. Dividing directly keeps
  sub-microsecond precision and works for every unit, including nanoseconds
  (scaling to microseconds with an integer factor would give a factor of 0
  there).
* Decimals with precision up to 9 ('Int32' storage) or up to 18 ('Int64'
  storage) are scaled to 'Double' by @10 ^ scale@.

If the column does not have the expected element type, or for any other
logical type, the column is returned unchanged.
-}
applyLogicalType :: LogicalType -> DI.Column -> DI.Column
applyLogicalType (TimestampType _ unit) col =
    fromRight col (DI.mapColumn toUTC col)
  where
    ticksPerSecond = fromIntegral (unitDivisor unit)

    toUTC :: Int64 -> UTCTime
    toUTC raw = posixSecondsToUTCTime (fromIntegral raw / ticksPerSecond)
applyLogicalType (DecimalType precision scale) col
    | precision <= 9 = case DI.toVector @Int32 @VU.Vector col of
        Right xs ->
            DI.fromUnboxedVector $
                VU.map (\raw -> fromIntegral @Int32 @Double raw / 10 ^ scale) xs
        Left _ -> col
    | precision <= 18 = case DI.toVector @Int64 @VU.Vector col of
        Right xs ->
            DI.fromUnboxedVector $
                VU.map (\raw -> fromIntegral @Int64 @Double raw / 10 ^ scale) xs
        Left _ -> col
    | otherwise = col
applyLogicalType _ col = col

-- | Interpret a count of microseconds since the POSIX epoch as a 'UTCTime'.
microsecondsToUTCTime :: Int64 -> UTCTime
microsecondsToUTCTime us = posixSecondsToUTCTime secondsSinceEpoch
  where
    secondsSinceEpoch = fromIntegral us / 1_000_000

-- | Number of ticks of the given time unit in one second; an unknown unit maps to 1.
unitDivisor :: TimeUnit -> Int64
unitDivisor unit = case unit of
    MILLISECONDS -> 1_000
    MICROSECONDS -> 1_000_000
    NANOSECONDS -> 1_000_000_000
    TIME_UNIT_UNKNOWN -> 1

-- | Turn an unscaled decimal value into a 'Double' by dividing it by @10 ^ scale@.
applyScale :: Int32 -> Int32 -> Double
applyScale scale rawValue = unscaled / powerOfTen
  where
    unscaled = fromIntegral rawValue
    powerOfTen = 10 ^ scale
