{-# LANGUAGE GADTs #-}
-- | Decode PostgreSQL composite values into typed 'Composite' values.
-- Adapted code from: https://hackage.haskell.org/package/postgresql-simple-0.4.10.0/docs/src/Database-PostgreSQL-Simple-Arrays.html
module Internal.Composite where

import Internal.Interlude hiding (show, option) 
import Prelude (Show(..))
import Database.PostgreSQL.Simple.FromField
import Data.Typeable
import qualified Database.PostgreSQL.Simple.TypeInfo as TI
import Data.Attoparsec.ByteString.Char8 hiding (Result)
import Control.Applicative ((<|>), many)
import qualified Data.ByteString.Char8 as B
import Data.Foldable (toList)

-- | A heterogeneous sequence of values whose element types are tracked, in
-- order, by the type-level list @ts@.
data Composite (ts :: [*]) where
    -- The composite with no elements.
    EmptyComposite :: Composite '[]
    -- Prepend one element of type @t@ to a composite whose elements are @ts@.
    ConsComposite :: t -> Composite ts -> Composite (t ': ts)

-- | The empty composite renders as its constructor name.
instance Show (Composite '[]) where
    show EmptyComposite = "EmptyComposite"

-- | Renders as @ConsComposite \<head\> (\<tail\>)@.
--
-- The head element is shown at constructor-argument precedence (11), so a
-- value that needs parentheses in that position (for example @Just 1@ or a
-- negative number) is wrapped in them and the output stays unambiguous.
-- Values that never need parentheses render exactly as with plain 'show'.
instance (Show (Composite ts), Show t) => Show (Composite (t ': ts)) where
    show (ConsComposite t r) =
        -- 'showsPrec' takes the text that follows the element as its last argument.
        "ConsComposite " <> showsPrec 11 t (" (" <> show r <> ")")

-- | Decodes any PostgreSQL composite whose fields are compatible with the
-- element types @ts@; the work is delegated to 'pgCompositeFieldParser'.
instance FieldParsers ts => FromField (Composite ts) where
    fromField field mbytes = pgCompositeFieldParser field mbytes

-- | 'FieldParser' for composite-typed fields.
--
-- The field's 'TypeInfo' is looked up first. Decoding is attempted only when
-- it is a 'TI.Composite' or a 'TI.Basic' whose @typname@ is @"composite"@;
-- any other type info yields an 'Incompatible' error. When decoding is
-- attempted, a missing value yields 'UnexpectedNull' and text that
-- 'fromComposite' cannot parse yields 'ConversionFailed'.
pgCompositeFieldParser :: FieldParsers ts => FieldParser (Composite ts)
pgCompositeFieldParser field mbytes = do
    tyInfo <- typeInfo field
    case tyInfo of
        TI.Composite{} -> decodeWith tyInfo
        TI.Basic{typname = "composite"} -> decodeWith tyInfo
        _ -> returnError Incompatible field ("TypeInfo: " <> show tyInfo)
  where
    -- Parse the raw bytes, then run the conversion the parser produced.
    decodeWith tyInfo = case mbytes of
        Nothing -> returnError UnexpectedNull field ""
        Just bytes -> case parseOnly (fromComposite tyInfo field) bytes of
            Right conversion -> conversion
            Left parseErr -> returnError ConversionFailed field parseErr

-- | Type-level lists @ts@ for which the raw fields of a composite can be
-- decoded into a 'Composite' @ts@.
class Typeable ts => FieldParsers ts where
    -- | Decode raw composite fields into a 'Composite' @ts@.
    --
    -- Arguments, in order: the 'TypeInfo' of each attribute, the 'Field' the
    -- composite value was read from, and the raw fields themselves.
    fromCompositeFormats :: [TypeInfo] -> Field -> [CompositeFormat] -> Conversion (Composite ts)
-- | Base case: succeeds only when no attribute types and no raw fields are
-- left; otherwise more was received than the type @'[]@ allows.
instance FieldParsers '[] where
    fromCompositeFormats infos f raws = case (infos, raws) of
        ([], []) -> return EmptyComposite
        _ -> returnError Incompatible f "The Composite's type indicates a smaller number of elements than the composite that was received"
-- | Inductive case: decode the first raw field as a @t@, then the remaining
-- fields as @ts@.
--
-- The first field is decoded with 'fromField' on a copy of the original
-- 'Field' whose type OID is replaced by that of the attribute. A 'NullStr'
-- field is passed on as @Nothing@, any other field as @Just@ its bytes.
-- Running out of attribute types or raw fields is reported as 'Incompatible'.
instance (FromField t, Typeable t, FieldParsers ts) => FieldParsers (t ': ts) where
    fromCompositeFormats infos f raws = case (infos, raws) of
        (ti : restInfos, raw : restRaws) ->
            ConsComposite
                <$> fromField @t (f { typeOid = typoid ti }) (payload raw)
                <*> fromCompositeFormats @ts restInfos f restRaws
        _ -> returnError Incompatible f "The Composite's type indicates a greater number of elements than the composite that was received"
      where
        -- Bytes handed to the element parser; 'NullStr' means no value.
        payload NullStr = Nothing
        payload present = Just (fmt present)

-- | Build the parser for the textual form of a composite value.
--
-- The parser reads the raw fields with 'composite' and yields a 'Conversion'
-- that decodes them with 'fromCompositeFormats', using the attribute types
-- recorded in @ti@.
--
-- 'attributes' is a record field of the 'TI.Composite' constructor only, so
-- it is obtained by pattern matching rather than by applying the selector.
-- For any other 'TypeInfo' the parser consumes nothing and yields a
-- conversion that fails with 'Incompatible', instead of raising a
-- record-selector exception when the attribute list is forced.
fromComposite :: FieldParsers ts => TypeInfo -> Field -> Parser (Conversion (Composite ts))
fromComposite ti f = case ti of
    TI.Composite{attributes = attrs} ->
        fromCompositeFormats (toList (fmap atttype attrs)) f <$> composite
    _ ->
        return (returnError Incompatible f ("TypeInfo has no attributes: " <> show ti))

-- | Parse a single non-empty raw field: an unquoted run of characters
-- ('Plain') or, failing that, a double-quoted string ('Quoted').
compositeFormat :: Parser CompositeFormat
compositeFormat = fmap Plain plain <|> fmap Quoted quoted

-- | One raw field of a composite's textual representation, as produced by
-- the 'composite' parser.
data CompositeFormat = 
    -- Bytes of a field matched by 'plain' (no surrounding double quotes).
    Plain B.ByteString
    -- Contents of a field matched by 'quoted' (quotes removed, escapes resolved).
    | Quoted B.ByteString
    -- A position where neither 'quoted' nor 'plain' matched; it is decoded
    -- as @Nothing@ by 'fromCompositeFormats'.
    | NullStr
    deriving (Eq, Show, Ord)

-- | Parse a parenthesised, comma-separated list of raw fields.
--
-- Each position is tried as a quoted string, then as a plain string; when
-- neither matches, the position is recorded as 'NullStr' without consuming
-- input.
composite :: Parser [CompositeFormat]
composite = do
    _ <- char '('
    fields <- option [] (field `sepBy1` char ',')
    _ <- char ')'
    return fields
  where
    field = fmap Quoted quoted <|> fmap Plain plain <|> return NullStr

-- | Parse a double-quoted composite field and return its unescaped contents.
--
-- Between the quotes, a backslash escapes whichever character follows it
-- and two consecutive double quotes stand for a single double quote.
-- Accepting a backslash before any character (not only before another
-- backslash) lets this parser read back the backslash-quote sequences
-- written by 'esc'.
quoted :: Parser B.ByteString
quoted = char '"' *> contents <* char '"'
    where
    -- a backslash followed by any character yields that character
    backslashed = char '\\' *> anyChar
    -- two consecutive double quotes yield one double quote
    doubledQuote = char '"' *> char '"'
    -- longest non-empty run of characters that need no unescaping
    literal = takeWhile1 (notInClass "\"\\")
    -- 'many' also succeeds on zero pieces, so an empty quoted field yields
    -- the empty string without a separate default
    contents = mconcat <$> many (literal <|> B.singleton <$> (backslashed <|> doubledQuote))

-- | Parse a non-empty unquoted field: the longest run of characters that are
-- not a comma, double quote, parenthesis or space.
plain :: Parser B.ByteString
plain = takeWhile1 unquotedChar
    where
    unquotedChar = notInClass ",\"() "

-- | Bytes of a raw field without any quoting or escaping applied
-- ('fmt'' with quoting switched off).
fmt :: CompositeFormat -> B.ByteString
fmt field = fmt' False field

-- | Render every field with quoting enabled and join the results with
-- commas. An empty list renders as the empty string.
delimit :: [CompositeFormat] -> B.ByteString
delimit fields = B.intercalate (B.singleton ',') (fmap (fmt' True) fields)

-- | Render one raw field.
--
-- 'Plain' bytes are copied as they are and 'NullStr' renders as the empty
-- string, regardless of the flag. A 'Quoted' field is wrapped in double
-- quotes with its contents passed through 'esc' when the flag is @True@,
-- and copied unchanged when it is @False@.
fmt' :: Bool -> CompositeFormat -> B.ByteString
fmt' _ (Plain bytes) = B.copy bytes
fmt' _ NullStr = B.empty
fmt' quoting (Quoted q)
    | quoting   = B.concat [B.singleton '"', esc q, B.singleton '"']
    | otherwise = B.copy q

-- | Put a backslash in front of every double quote and every backslash;
-- all other characters are copied unchanged.
esc :: B.ByteString -> B.ByteString
esc = B.concatMap escapeChar
    where
    escapeChar c = case c of
        '"'  -> B.pack "\\\""
        '\\' -> B.pack "\\\\"
        _    -> B.singleton c