{-# LANGUAGE OverloadedStrings #-}
module NLP.Types.Tree where

import Prelude hiding (print)
import Control.Applicative ((<$>), (<*>))
import Data.String (IsString(..))
import Data.Text (Text)
import qualified Data.Text as T

import Test.QuickCheck (Arbitrary(..), listOf, elements, NonEmptyList(..))
import Test.QuickCheck (resize, sized)
import Test.QuickCheck.Instances ()

import NLP.Types.Tags

-- | A sentence of tokens without tags.  Generated by the tokenizer:
--
-- > tokenizer :: Text -> Sentence
--
-- The single constructor 'Sent' wraps the list of 'Token's; 'tokens'
-- unwraps it again.
data Sentence = Sent [Token]
  deriving (Read, Show, Eq)

-- | Wraps an arbitrary list of 'Token's in 'Sent'.
instance Arbitrary Sentence where
  arbitrary = fmap Sent arbitrary

-- | Unwrap a 'Sentence', yielding its 'Token's in order.
tokens :: Sentence -> [Token]
tokens sentence = case sentence of
  Sent toks -> toks

-- | Pair each token of a 'Sentence' with the tag at the same position
-- in the supplied list, producing a 'TaggedSentence'.
--
-- The lists are zipped, so if they differ in length the surplus
-- tokens or tags are dropped.
applyTags :: Tag t => Sentence -> [t] -> TaggedSentence t
applyTags (Sent toks) tags =
  TaggedSent [ POS tag tok | (tag, tok) <- zip tags toks ]

-- | A chunked sentence has POS tags and chunk tags. Generated by a
-- chunker:
--
-- > chunker :: (Chunk chunk, Tag tag) => TaggedSentence tag -> ChunkedSentence chunk tag
--
-- The top level is a list of 'ChunkOr' nodes, each of which is either
-- a (possibly nested) chunk or a bare POS-tagged token.
data ChunkedSentence chunk tag = ChunkedSent [ChunkOr chunk tag]
  deriving (Read, Show, Eq)

-- | A data type to represent the portions of a parse tree for Chunks.
-- Note that this part of the parse tree could be a POS tag with no
-- chunk.
data ChunkOr chunk tag = Chunk_CN (Chunk chunk tag) -- ^ A chunk node, which may itself contain further nodes.
                       | POS_CN   (POS tag)         -- ^ A POS-tagged token that is not wrapped in a chunk at this level.
                         deriving (Read, Show, Eq)

-- | A Chunk that strictly contains chunks or POS tags.
--
-- The first field is the chunk's label; the second is the list of
-- child nodes ('ChunkOr' values) the chunk spans.
data Chunk chunk tag = Chunk chunk [ChunkOr chunk tag]
  deriving (Read, Show, Eq)

-- | Render a 'ChunkedSentence' as space-separated 'Text'.
--
-- Bare POS-tagged tokens are rendered with 'printPOS'.  A chunk is
-- rendered as @[@ followed by its label ('fromChunk'), then its
-- rendered children separated by spaces, then a closing @]@.  Nested
-- chunks are rendered recursively.
showChunkedSent :: (ChunkTag c, Tag t) => ChunkedSentence c t -> Text
showChunkedSent (ChunkedSent nodes) = T.unwords (map render nodes)
  where
    -- Render a single node of the chunk tree.
    render (POS_CN pos) = printPOS pos
    render (Chunk_CN (Chunk label kids)) =
      let opening = T.append "[" (fromChunk label)
      in T.concat [T.unwords (opening : map render kids), "]"]

-- | Wraps an arbitrary list of 'ChunkOr' nodes in 'ChunkedSent'.
instance (ChunkTag c, Arbitrary c, Arbitrary t, Tag t) =>
  Arbitrary (ChunkedSentence c t) where
  arbitrary = fmap ChunkedSent arbitrary

-- | A tagged sentence has POS Tags.  Generated by a part-of-speech
-- tagger:
--
-- > tagger :: Tag tag => Sentence -> TaggedSentence tag
--
-- The single constructor 'TaggedSent' wraps the list of 'POS' values,
-- one per token.
data TaggedSentence tag = TaggedSent [POS tag]
  deriving (Read, Show, Eq)

-- | Wraps an arbitrary list of 'POS' values in 'TaggedSent'.
instance (Arbitrary t, Tag t) => Arbitrary (TaggedSentence t) where
  arbitrary = fmap TaggedSent arbitrary

-- | Render a 'TaggedSentence' in the common @token/tag@ format, with
-- entries separated by single spaces, e.g.:
--
-- > "the/at dog/nn jumped/vbd ./."
--
-- Each entry is produced by 'printPOS'.
printTS :: Tag t => TaggedSentence t -> Text
printTS (TaggedSent poss) = T.unwords (map printPOS poss)

-- | Discard the tags of a 'TaggedSentence', keeping only the
-- underlying 'Sentence' (the first component of 'unzipTags').
stripTags :: Tag t => TaggedSentence t -> Sentence
stripTags = fst . unzipTags

-- | Split a 'TaggedSentence' into its untagged 'Sentence' and the
-- parallel list of tags (same order, same length as the tokens).
unzipTags :: Tag t => TaggedSentence t -> (Sentence, [t])
unzipTags (TaggedSent poss) = (Sent toks, tags)
  where
    (tags, toks) = unzip [ (tag, tok) | POS tag tok <- poss ]

-- | Flatten a 'ChunkedSentence' into a 'TaggedSentence' plus a
-- parallel list of chunk tags, one per token.
--
-- A token that is not inside any chunk is paired with 'notChunk'.  A
-- token inside nested chunks keeps the tag of its innermost enclosing
-- chunk: an outer chunk's tag only replaces entries that are still
-- 'notChunk'.
unzipChunks :: (ChunkTag c, Tag t) => ChunkedSentence c t -> (TaggedSentence t, [c])
unzipChunks (ChunkedSent nodes) = (TaggedSent poss, chunkTags)
  where
    (poss, chunkTags) = unzip (concatMap flatten nodes)

    -- flatten :: ChunkOr chunk tag -> [(POS tag, chunk)]
    flatten (POS_CN pos) = [(pos, notChunk)]
    flatten (Chunk_CN (Chunk label kids)) =
      [ relabel label entry | entry <- concatMap flatten kids ]

    -- relabel :: c -> (POS t, c) -> (POS t, c)
    -- Only entries that have no chunk tag yet receive the enclosing label.
    relabel label (pos, current) =
      if current == notChunk
        then (pos, label)
        else (pos, current)


-- | Combine the results of POS taggers pairwise, using the sentences
-- in the second list to fill in 'tagUNK' entries of the first, where
-- possible.  Each pair is merged with 'combineSentences'.
--
-- The lists are zipped, so surplus sentences in the longer list are
-- dropped.
combine :: Tag t => [TaggedSentence t] -> [TaggedSentence t] -> [TaggedSentence t]
combine = zipWith combineSentences

-- | Merge two 'TaggedSentence' values position by position, preferring
-- the tags in the first 'TaggedSentence'.  Delegates to 'pickTag' for
-- each pair of entries.
--
-- The entries are zipped, so surplus entries in the longer sentence
-- are dropped.
combineSentences :: Tag t => TaggedSentence t -> TaggedSentence t -> TaggedSentence t
combineSentences (TaggedSent primary) (TaggedSent fallback) = TaggedSent merged
  where
    merged = zipWith pickTag primary fallback

-- | Choose between two taggings of the same token.
--
-- Returns the first argument unless its tag is 'tagUNK', in which
-- case the second argument's tag is paired with the first argument's
-- token.
--
-- Calls 'error' if the two tokens are not equal.
pickTag :: Tag t => POS t -> POS t -> POS t
pickTag first@(POS tagA tokA) second@(POS tagB tokB)
  | tokA /= tokB   = error mismatch
  | tagA /= tagUNK = first
  | otherwise      = POS tagB tokA
  where
    mismatch = "Text does not match: "++ show first ++ " " ++ show second


-- | Generates either a chunk with arbitrary children or a bare
-- POS-tagged token ("chink"), each with equal probability.
--
-- The generator is bounded by QuickCheck's size parameter: the
-- children of a chunk are generated at a quarter of the current size,
-- and at size zero only chinks are produced.  Without this bound the
-- recursion through 'listOf' never shrinks, so every chunk would
-- spawn, on average, more child chunks than it consumes and the
-- generated tree could be infinite.
instance (ChunkTag c, Arbitrary c, Arbitrary t, Tag t) => Arbitrary (ChunkOr c t) where
  arbitrary = sized $ \n ->
    let chinkGen = mkChink <$> arbitrary <*> arbitrary
        -- Children get a strictly smaller size budget, which
        -- guarantees the recursion bottoms out.
        chunkGen = mkChunk <$> arbitrary <*> resize (n `div` 4) (listOf arbitrary)
    in if n <= 0
         then chinkGen
         else do
           chunk <- chunkGen
           chink <- chinkGen
           -- Generated values are lazy, so the alternative that is not
           -- picked is never built.
           elements [chunk, chink]

-- | Build a 'ChunkOr' node holding a chunk with the given label and
-- child nodes.
mkChunk :: (ChunkTag chunk, Tag tag) => chunk -> [ChunkOr chunk tag] -> ChunkOr chunk tag
mkChunk chunk = Chunk_CN . Chunk chunk

-- | Build a 'ChunkOr' node that just holds a POS-tagged token, with
-- no chunk around it.
mkChink :: (ChunkTag chunk, Tag tag) => tag -> Token -> ChunkOr chunk tag
mkChink tag = POS_CN . POS tag


-- | Pairs an arbitrary chunk label with an arbitrary list of child
-- nodes.
instance (ChunkTag c, Arbitrary c, Arbitrary t, Tag t) => Arbitrary (Chunk c t) where
  arbitrary = fmap Chunk arbitrary <*> arbitrary

-- | A POS-tagged token: a 'Token' paired with its part-of-speech tag.
data POS tag = POS { posTag :: tag
                     -- ^ The tag attached to the token.
                   , posToken :: Token
                     -- ^ The token that carries the tag.
                   } deriving (Read, Show, Eq)

-- | Pairs an arbitrary tag with an arbitrary 'Token'.
instance (Arbitrary t, Tag t) => Arbitrary (POS t) where
  arbitrary = fmap POS arbitrary <*> arbitrary

-- | The raw text of the token inside a 'POS', ignoring the tag.
showPOStok :: Tag tag => POS tag -> Text
showPOStok = showTok . posToken

-- | The textual form of the tag inside a 'POS', as given by
-- 'tagTerm'; the token is ignored.
showPOStag :: Tag tag => POS tag -> Text
showPOStag pos = tagTerm (posTag pos)

-- | Render a 'POS' as @token/tag@: the token text, a slash, then the
-- tag's textual form ('tagTerm').
printPOS :: Tag tag => POS tag -> Text
printPOS (POS tag (Token txt)) = T.concat [txt, "/", tagTerm tag]


-- | Raw tokenized text.
--
-- 'Token' has an 'IsString' instance to simplify use, so with
-- @OverloadedStrings@ a string literal can be used directly where a
-- 'Token' is expected.
data Token = Token Text
  deriving (Read, Show, Eq)

-- | Generates tokens whose text is never empty, by drawing from
-- QuickCheck's 'NonEmptyList' of characters.
instance Arbitrary Token where
  arbitrary = arbitrary >>= \(NonEmpty chars) -> return (Token (T.pack chars))

-- | Packs the 'String' into 'Text' and wraps it in 'Token'.
instance IsString Token where
  fromString str = Token (T.pack str)

-- | Unwrap a 'Token', yielding its text.
showTok :: Token -> Text
showTok tok = case tok of
  Token txt -> txt

-- | The last three characters of a 'Token'.  Tokens of three
-- characters or fewer are returned whole.
suffix :: Token -> Text
suffix (Token str) =
  let len = T.length str
  in if len <= 3
       then str
       else T.drop (len - 3) str

-- | Unwrap a 'TaggedSentence', yielding its 'POS' entries in order.
unTS :: Tag t => TaggedSentence t -> [POS t]
unTS sent = case sent of
  TaggedSent poss -> poss

-- | The number of tokens in a 'TaggedSentence'.
tsLength :: Tag t => TaggedSentence t -> Int
tsLength = length . unTS

-- | Brutally concatenate a list of 'TaggedSentence's into a single
-- one, joining their entries end to end with no separator.
tsConcat :: Tag t => [TaggedSentence t] -> TaggedSentence t
tsConcat tss = TaggedSent [ pos | TaggedSent poss <- tss, pos <- poss ]

-- | True if some token of the sentence is exactly equal to the given
-- text.  No partial or approximate matching is done, and the
-- comparison is fully case-sensitive.
contains :: Tag t => TaggedSentence t -> Text -> Bool
contains sent tok = any (posTokMatches tok) (unTS sent)

-- | True if some entry of the sentence carries exactly the given POS
-- tag.  No partial matching (such as prefix matching) is done.
containsTag :: Tag t => TaggedSentence t -> t -> Bool
containsTag sent tag = any (posTagMatches tag) (unTS sent)

-- | True if the tag of the 'POS' equals the supplied tag.
posTagMatches :: Tag t => t -> POS t -> Bool
posTagMatches wanted = (wanted ==) . posTag

-- | True if the token of the 'POS' has exactly the supplied text
-- (see 'tokenMatches'); the tag is ignored.
posTokMatches :: Tag t => Text -> POS t -> Bool
posTokMatches txt = tokenMatches txt . posToken

-- | True if the token's text is exactly equal to the supplied text.
tokenMatches :: Text -> Token -> Bool
tokenMatches txt tok = txt == showTok tok