#!/usr/bin/env python3
import os
import sys
import json
import re
import logging
import tempfile
from typing import Dict, List, Tuple
import pandas as pd
import yaml
import nltk
from nltk.tokenize import sent_tokenize as nltk_sent_tokenize, word_tokenize
from pypdf import PdfReader
from docx import Document
import win32com.client
import pythoncom
from hazm import Normalizer, sent_tokenize as hazm_sent_tokenize
import arabic_reshaper
from bidi.algorithm import get_display

# Download required NLTK data once at startup.
# NLTK >= 3.8.2 looks up "punkt_tab" and "averaged_perceptron_tagger_eng"
# instead of the legacy pickled "punkt" / "averaged_perceptron_tagger"
# packages, so both generations are requested to work with either release.
nltk_resources = [
    "punkt",
    "punkt_tab",
    "averaged_perceptron_tagger",
    "averaged_perceptron_tagger_eng",
    "stopwords",
]
for _resource in nltk_resources:
    # quiet=True suppresses the per-package progress banner. A failed download
    # is not fatal here: tokenization reports a LookupError later if the data
    # is genuinely missing.
    nltk.download(_resource, quiet=True)


class DocumentProcessor:
    """Extract text, tables and heading structure from documents.

    Handles PDF, DOC/DOCX, CSV/Excel, plain text, Markdown, JSON and YAML
    inputs, converts Persian/Arabic digits to ASCII and derives sentences,
    sections and key phrases from the extracted text.

    Legacy ``.doc`` files are converted through a Word COM automation session
    that is started on first use. Call :meth:`close` (or let the instance be
    garbage collected) to shut that session down.
    """

    # Supported MIME types mapped to the file extension that selects them
    SUPPORTED_TYPES = {
        "application/pdf": ".pdf",
        "application/msword": ".doc",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "text/csv": ".csv",
        "application/vnd.ms-excel": ".xls",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
        "text/plain": ".txt",
        "text/markdown": ".md",
        "application/json": ".json",
        "application/yaml": ".yml",
        "application/x-yaml": ".yaml",
    }

    # Persian (U+06F0-U+06F9) and Arabic-Indic (U+0660-U+0669) digits plus the
    # Arabic comma (U+060C), keyed by code point for str.translate().
    _DIGIT_TABLE = {
        **{0x06F0 + value: str(value) for value in range(10)},
        **{0x0660 + value: str(value) for value in range(10)},
        0x060C: ",",
    }
    # A run of same-line whitespace/commas between a digit and a complete
    # three-digit group, i.e. a thousands separator. Line breaks are excluded
    # so numbers on adjacent lines are never merged.
    _THOUSANDS_GAP = re.compile(r"(?<=\d)(?:[^\S\r\n]|,)+(?=\d{3}(?!\d))")
    _PERSIAN_CHARS = re.compile(r"[۰-۹آ-ی]")
    _LATIN_CHARS = re.compile(r"[a-zA-Z]")
    # Word's second-level bullet is extracted as a bare "o" followed by space.
    _O_BULLET = re.compile(r"o(?:\s|$)")

    def __init__(self):
        # Assigned before anything that can fail so close() always finds them.
        self.word = None
        self._com_initialized = False
        self.normalizer = Normalizer()
        # Initialize COM for Word; the Word application is created on demand.
        pythoncom.CoInitialize()
        self._com_initialized = True

    def close(self) -> None:
        """Quit the Word session (if any) and release COM. Safe to call twice."""
        word = getattr(self, "word", None)
        if word:
            self.word = None
            try:
                word.Quit()
            except Exception as e:
                # Word may already be gone; there is nothing left to release.
                logging.debug(f"Ignoring error while quitting Word: {str(e)}")
        # Only balance a CoInitialize() this instance actually performed.
        if getattr(self, "_com_initialized", False):
            self._com_initialized = False
            pythoncom.CoUninitialize()

    def __del__(self):
        try:
            self.close()
        except Exception:
            # Finalizers can run during interpreter shutdown when module
            # globals are already torn down; cleanup here is best-effort.
            pass

    def convert_to_english_numbers(self, text: str) -> str:
        """Convert Persian/Arabic digits and commas to ASCII.

        Thousands separators written as spaces or commas on the same line are
        also normalized to a single comma ("1 234 567" -> "1,234,567").

        Args:
            text: Text to convert; falsy input yields an empty string.

        Returns:
            The converted text.
        """
        if not text:
            return ""
        converted = text.translate(self._DIGIT_TABLE)
        return self._THOUSANDS_GAP.sub(",", converted)

    def clean_table_text(self, text: str) -> str:
        """Flatten, normalize and digit-convert text extracted from a table cell."""
        if not text:
            return ""
        single_line = text.replace("\n", " ")
        normalized = self.normalizer.normalize(single_line)
        return self.convert_to_english_numbers(normalized)

    @staticmethod
    def _paragraph_text(paragraph) -> str:
        """Return ``paragraph.text`` when available, else ``str(paragraph)``."""
        return paragraph.text if hasattr(paragraph, "text") else str(paragraph)

    def is_header(self, paragraph) -> bool:
        """Heuristically decide whether a paragraph is a header.

        A non-blank paragraph counts as a header when its style name starts
        with "Heading", when every text-bearing run is bold, when it ends with
        a colon, or when it is short but spans several lines.

        Args:
            paragraph: A python-docx style paragraph object or a plain string.

        Returns:
            True if the paragraph looks like a header, False otherwise
            (including when inspecting the paragraph raises).
        """
        try:
            text = self._paragraph_text(paragraph)
            stripped = text.strip()
            if not stripped:
                # A blank header would swallow the section that follows it.
                return False

            style_name = getattr(getattr(paragraph, "style", None), "name", None)
            if isinstance(style_name, str) and style_name.startswith("Heading"):
                return True

            # Require all visible runs to be bold: a single emphasised word in
            # a sentence must not turn the whole paragraph into a header.
            runs = getattr(paragraph, "runs", None) or []
            visible_runs = [run for run in runs if getattr(run, "text", "").strip()]
            if visible_runs and all(run.bold for run in visible_runs):
                return True

            if stripped.endswith((":", "：")):
                return True
            return len(stripped) < 50 and text.count("\n") > 1
        except Exception as e:
            logging.debug(f"Header detection failed: {str(e)}")
            return False

    def extract_structure(self, paragraphs) -> List[Dict]:
        """Group paragraphs into ``{"header", "content"}`` sections.

        Content that appears before the first detected header is not part of
        any section and is omitted.

        Args:
            paragraphs: Iterable of paragraph objects or strings.

        Returns:
            One dict per header, with the following paragraphs joined by
            newlines as its content.
        """
        structure = []
        current_header = None
        current_content: List[str] = []

        for para in paragraphs:
            para_text = self._paragraph_text(para)
            if not self.is_header(para):
                current_content.append(para_text)
                continue
            if current_header is not None:
                structure.append(
                    {"header": current_header, "content": "\n".join(current_content)}
                )
            current_header = para_text
            current_content = []

        if current_header is not None:
            structure.append(
                {"header": current_header, "content": "\n".join(current_content)}
            )
        return structure

    def extract_sections(self, text: str) -> Dict[str, List[str]]:
        """Extract sections from text based on short header-like lines.

        Lines shorter than 50 characters without a colon start a section; other
        lines are appended to the current section. Blank lines and bullet lines
        (containing "•" or starting with Word's "o" bullet) are skipped.

        Args:
            text: Newline-separated text.

        Returns:
            Mapping of section label to its lines. A label that occurs more
            than once keeps accumulating into the same list.
        """
        data: Dict[str, List[str]] = {}
        current_label = None

        for raw_line in text.split("\n"):
            line = raw_line.strip()
            if not line or "•" in line or self._O_BULLET.match(line):
                continue
            if len(line) < 50 and ":" not in line:
                current_label = line
                data.setdefault(current_label, [])
            elif current_label:
                data[current_label].append(line)

        return data

    def detect_language(self, text: str) -> str:
        """Return "fa" if Persian characters outnumber Latin letters, else "en"."""
        persian_count = len(self._PERSIAN_CHARS.findall(text))
        english_count = len(self._LATIN_CHARS.findall(text))
        return "fa" if persian_count > english_count else "en"

    def extract_keyphrases(self, text: str, language: str) -> List[str]:
        """Extract up to ten key phrases from text.

        For "fa" the phrases are the whitespace-separated words longer than
        two characters; otherwise they are maximal runs of noun/adjective
        tokens according to the NLTK part-of-speech tagger.

        Args:
            text: Source text.
            language: "fa" for Persian; any other value is treated as English.

        Returns:
            At most ten phrases in order of appearance.

        Raises:
            Exception: If tokenization or tagging fails.
        """
        text = self.convert_to_english_numbers(text)
        try:
            if language == "fa":
                return [word for word in text.split() if len(word) > 2][:10]

            phrases = []
            current_phrase = []
            for word, tag in nltk.pos_tag(word_tokenize(text)):
                if tag.startswith(("NN", "JJ")):
                    current_phrase.append(word)
                elif current_phrase:
                    phrases.append(" ".join(current_phrase))
                    current_phrase = []
            if current_phrase:
                phrases.append(" ".join(current_phrase))
            return phrases[:10]
        except Exception as e:
            raise Exception(f"Error extracting keyphrases: {str(e)}") from e

    def _extract_docx_tables(self, document) -> List[Dict]:
        """Convert the tables of a python-docx document into header/data dicts.

        The first row of each table is used as its headers. Tables without
        rows, and tables whose remaining rows are all empty, are skipped.
        """
        tables = []
        for idx, table in enumerate(document.tables):
            rows = list(table.rows)
            if not rows:
                continue
            headers = [
                self.clean_table_text(cell.text.strip()) for cell in rows[0].cells
            ]
            data = []
            for row in rows[1:]:
                row_data = [
                    self.clean_table_text(cell.text.strip()) for cell in row.cells
                ]
                if any(row_data):
                    data.append(row_data)
            if headers and data:
                tables.append({"table_index": idx, "headers": headers, "data": data})
        return tables

    def extract_doc_text(self, file_path: str) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text from DOC files using the Word application.

        The file is opened in Word, saved as DOCX into a private temporary
        directory and then parsed like any other DOCX file.

        Args:
            file_path: Path to the ``.doc`` file.

        Returns:
            Tuple of (text, tables, structure).

        Raises:
            Exception: If Word automation or parsing fails.
        """
        temp_dir = None
        temp_docx = None
        try:
            if not self.word:
                self.word = win32com.client.Dispatch("Word.Application")
            self.word.Visible = False

            # A fresh directory guarantees a path nobody else can pre-create,
            # unlike the deprecated, race-prone tempfile.mktemp().
            temp_dir = tempfile.mkdtemp(prefix="docproc_")
            temp_docx = os.path.join(temp_dir, "converted.docx")

            doc = self.word.Documents.Open(os.path.abspath(file_path))
            try:
                doc.SaveAs2(temp_docx, FileFormat=16)  # 16 = DOCX format
            finally:
                # Always release the document so Word does not keep it locked.
                doc.Close()

            docx = Document(temp_docx)
            text = "\n".join([paragraph.text for paragraph in docx.paragraphs])
            structure = self.extract_structure(docx.paragraphs)
            tables = self._extract_docx_tables(docx)
            return text, tables, structure

        except Exception as e:
            raise Exception(
                f"Error extracting content from DOC {file_path}: {str(e)}"
            ) from e
        finally:
            if temp_dir is not None:
                # Best-effort cleanup; the file is absent if conversion failed.
                for remove, path in ((os.unlink, temp_docx), (os.rmdir, temp_dir)):
                    try:
                        remove(path)
                    except OSError as cleanup_error:
                        logging.debug(
                            f"Could not remove temporary path {path}: {cleanup_error}"
                        )

    def extract_docx_text(self, file_path: str) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text, tables and heading structure from a DOCX document.

        Args:
            file_path: Path to the ``.docx`` file.

        Returns:
            Tuple of (text, tables, structure).

        Raises:
            Exception: If the document cannot be read.
        """
        try:
            doc = Document(file_path)
            text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
            structure = self.extract_structure(doc.paragraphs)
            tables = self._extract_docx_tables(doc)
            return text, tables, structure

        except Exception as e:
            raise Exception(
                f"Error extracting content from DOCX {file_path}: {str(e)}"
            ) from e

    def _clean_pdf_cell(self, value):
        """Clean, reshape and bidi-reorder a PDF table cell; pass non-strings through."""
        if not isinstance(value, str):
            return value
        reshaped = arabic_reshaper.reshape(self.clean_table_text(value))
        return get_display(reshaped)

    def extract_pdf_text(self, file_path: str) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text and tables from PDF documents.

        Table extraction relies on the optional ``camelot`` package and is
        best-effort: if it fails, a warning is logged and no tables are
        returned.

        Args:
            file_path: Path to the PDF file.

        Returns:
            Tuple of (text, tables, structure); structure is always empty.

        Raises:
            Exception: If the PDF text itself cannot be read.
        """
        try:
            reader = PdfReader(file_path)
            tables = []
            structure = []

            page_texts = (page.extract_text() for page in reader.pages)
            text = "".join(page_text + "\n" for page_text in page_texts if page_text)

            try:
                import camelot

                pdf_tables = camelot.read_pdf(file_path, pages="all", flavor="stream")
                for idx, table in enumerate(pdf_tables):
                    df = table.df
                    if df.empty or len(df) <= 1:
                        continue
                    # The first row holds the headers, the rest is data.
                    headers = df.iloc[0].tolist()
                    rows = df.iloc[1:].values.tolist()
                    tables.append(
                        {
                            "table_index": idx,
                            "headers": [self._clean_pdf_cell(h) for h in headers],
                            "data": [
                                [self._clean_pdf_cell(cell) for cell in row]
                                for row in rows
                            ],
                        }
                    )
            except Exception as e:
                logging.warning(f"Table extraction failed: {str(e)}")

            return text, tables, structure

        except Exception as e:
            raise Exception(
                f"Error extracting content from PDF {file_path}: {str(e)}"
            ) from e

    def extract_text(
        self, file_path: str, mime_type: str
    ) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text and tables from various file formats.

        Args:
            file_path: Path to the input file.
            mime_type: One of the keys of ``SUPPORTED_TYPES``; any other value
                yields empty results.

        Returns:
            Tuple of (text, tables, structure).

        Raises:
            Exception: If reading or parsing the file fails.
        """
        try:
            tables = []
            text = ""
            structure = []

            if mime_type == "application/pdf":
                text, tables, structure = self.extract_pdf_text(file_path)
            elif mime_type == "application/msword":
                text, tables, structure = self.extract_doc_text(file_path)
            elif (
                mime_type
                == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            ):
                text, tables, structure = self.extract_docx_text(file_path)
            elif mime_type == "text/csv":
                df = pd.read_csv(file_path)
                text = df.to_string()
                if not df.empty:
                    tables.append(
                        {"headers": df.columns.tolist(), "data": df.values.tolist()}
                    )
            elif mime_type in [
                "application/vnd.ms-excel",
                "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            ]:
                sheet_texts = []
                # The context manager closes the workbook handle; reading the
                # sheets from it avoids reopening the file once per sheet.
                with pd.ExcelFile(file_path) as workbook:
                    for sheet_name in workbook.sheet_names:
                        df = pd.read_excel(workbook, sheet_name=sheet_name)
                        sheet_texts.append(f"\nSheet: {sheet_name}\n" + df.to_string())
                        if not df.empty:
                            tables.append(
                                {
                                    "sheet_name": sheet_name,
                                    "headers": df.columns.tolist(),
                                    "data": df.values.tolist(),
                                }
                            )
                text = "".join(sheet_texts)
            # "utf-8-sig" reads plain UTF-8 and also strips a leading BOM,
            # which would otherwise break json.load and pollute the text.
            elif mime_type in ["text/plain", "text/markdown"]:
                with open(file_path, "r", encoding="utf-8-sig") as f:
                    text = f.read()
            elif mime_type == "application/json":
                with open(file_path, "r", encoding="utf-8-sig") as f:
                    content = json.load(f)
                text = json.dumps(content, ensure_ascii=False, indent=2)
            elif mime_type in ["application/yaml", "application/x-yaml"]:
                with open(file_path, "r", encoding="utf-8-sig") as f:
                    content = yaml.safe_load(f)
                text = yaml.dump(content, allow_unicode=True)
            return text, tables, structure

        except Exception as e:
            raise Exception(
                f"Error extracting content from {file_path}: {str(e)}"
            ) from e

    def process_file(self, input_path: str, mime_type: str) -> Dict:
        """Process a single file and return its processed content.

        Args:
            input_path: Path to the file.
            mime_type: MIME type used to select the extractor.

        Returns:
            Dict with ``file_name``, ``content`` (language, sentences,
            keyphrases, keyphrase clusters, sections, tables, structure) and
            ``display`` (normalized text). When no text could be extracted the
            dict carries an ``error`` message and mostly empty content.

        Raises:
            Exception: If extraction or analysis fails.
        """
        try:
            text, tables, structure = self.extract_text(input_path, mime_type)
            if not text.strip():
                return {
                    "file_name": os.path.basename(input_path),
                    "error": "No text content could be extracted",
                    "content": {
                        "language": None,
                        "sentences": [],
                        "tables": tables,
                        "structure": structure,
                    },
                    "display": {
                        "normalized_text": "",
                    },
                }

            normalized = self.convert_to_english_numbers(text)
            sections = self.extract_sections(normalized)
            language = self.detect_language(normalized)

            # Key phrases come from section headers. A dict is used instead of
            # a set so duplicates are dropped while first-seen order is kept,
            # making the ten retained phrases deterministic between runs.
            unique_phrases: Dict[str, None] = {}
            for section in structure:
                header_text = section.get("header", "")
                for phrase in self.extract_keyphrases(header_text, language):
                    unique_phrases.setdefault(phrase, None)
            keyphrases = list(unique_phrases)[:10]

            # Convert text into sentences with better error handling
            try:
                sentences = (
                    nltk_sent_tokenize(normalized)
                    if language == "en"
                    else hazm_sent_tokenize(normalized)
                )
            except LookupError as e:
                logging.error(f"NLTK resource error: {str(e)}")
                logging.info("Falling back to basic sentence splitting")
                # Simple fallback using periods for sentence splitting
                sentences = [s.strip() for s in normalized.split(".") if s.strip()]
            except Exception as e:
                logging.error(f"Sentence tokenization error: {str(e)}")
                sentences = []

            # Cluster sentences under the key phrases they mention
            keyphrase_clusters = {
                kp: [sent for sent in sentences if kp.lower() in sent.lower()]
                for kp in keyphrases
            }

            return {
                "file_name": os.path.basename(input_path),
                "content": {
                    "language": language,
                    "sentences": sentences,
                    "keyphrases": keyphrases,
                    "keyphrase_clusters": keyphrase_clusters,
                    "sections": sections,
                    "tables": tables,
                    "structure": structure,
                },
                "display": {"normalized_text": normalized},
            }

        except Exception as e:
            raise Exception(f"Error processing file {input_path}: {str(e)}") from e


def process_documents(input_dir: str, output_dir: str) -> None:
    """Process all files in the input directory and save results to output directory.

    Every regular file whose extension is listed in
    ``DocumentProcessor.SUPPORTED_TYPES`` is processed, in sorted name order.
    Per-file failures are recorded in the results instead of aborting the run.
    The combined results are written to ``processing_results.json`` inside
    ``output_dir``.

    Args:
        input_dir: Directory containing the documents to process.
        output_dir: Directory for the results file; created if missing.

    Raises:
        FileNotFoundError: If ``input_dir`` does not exist.
        Exception: If listing the inputs or writing the results fails.
    """
    logging.info("Starting document processing pipeline")
    logging.info(f"Input directory: {input_dir}")
    logging.info(f"Output directory: {output_dir}")

    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory does not exist: {input_dir}")

    os.makedirs(output_dir, exist_ok=True)
    processor = DocumentProcessor()
    results: List[Dict] = []
    output_path = os.path.join(output_dir, "processing_results.json")

    try:
        # Invert the MIME table once; setdefault keeps the first MIME type
        # should two entries ever share an extension.
        mime_by_ext: Dict[str, str] = {}
        for mime, extension in processor.SUPPORTED_TYPES.items():
            mime_by_ext.setdefault(extension, mime)

        # Sorted for reproducible output; sub-directories are not documents.
        files = [
            name
            for name in sorted(os.listdir(input_dir))
            if os.path.isfile(os.path.join(input_dir, name))
        ]
        logging.info(f"Found {len(files)} files to process")

        for filename in files:
            file_path = os.path.join(input_dir, filename)
            # Skip our own results file when input and output directories are
            # the same, otherwise every run would re-ingest the previous one.
            if os.path.abspath(file_path) == os.path.abspath(output_path):
                continue

            ext = os.path.splitext(filename)[1].lower()
            logging.info(f"Processing file: {filename}")
            logging.info(f"File extension: {ext}")

            mime_type = mime_by_ext.get(ext)
            if not mime_type:
                logging.warning(f"Unsupported file type: {filename}")
                continue

            try:
                logging.info(f"MIME type identified: {mime_type}")
                result = processor.process_file(file_path, mime_type)
                results.append(result)
                logging.info(f"Successfully processed {filename}")
            except Exception as e:
                error_msg = f"Error processing {filename}: {str(e)}"
                logging.error(error_msg)
                results.append({"file_name": filename, "error": error_msg})

        with open(output_path, "w", encoding="utf-8") as f:
            # default=str is deliberate: spreadsheet cells can hold values the
            # json module cannot encode (e.g. timestamps), and one such cell
            # must not discard the results of the whole batch.
            json.dump(results, f, ensure_ascii=False, indent=2, default=str)
        logging.info(f"Results saved to {output_path}")

    except Exception as e:
        error_msg = f"Fatal error during document processing: {str(e)}"
        logging.error(error_msg)
        raise Exception(error_msg) from e


if __name__ == "__main__":
    import codecs

    # Force UTF-8 on the console streams so Persian text in log messages can
    # be written on consoles that default to a legacy code page.
    for _stream_name in ("stdout", "stderr"):
        _stream = getattr(sys, _stream_name)
        if _stream is None:
            # No console attached (e.g. pythonw); nothing to reconfigure.
            continue
        try:
            # codecs.lookup() normalizes spellings such as "UTF-8" or "utf8".
            _already_utf8 = codecs.lookup(_stream.encoding).name == "utf-8"
        except (LookupError, TypeError, AttributeError):
            _already_utf8 = False
        if _already_utf8:
            continue
        if hasattr(_stream, "reconfigure"):
            # Keeps the original stream object (and its attributes) intact.
            _stream.reconfigure(encoding="utf-8", errors="strict")
        else:
            setattr(
                sys, _stream_name, codecs.getwriter("utf-8")(_stream.buffer, "strict")
            )

    # Set up logging with UTF-8 encoding
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler("document_processing.log", encoding="utf-8"),
        ],
    )

    if len(sys.argv) != 3:
        logging.error(
            "Usage: python document_processor.py <input_directory> <output_directory>"
        )
        sys.exit(1)

    input_dir = sys.argv[1]
    output_dir = sys.argv[2]

    if not os.path.isdir(input_dir):
        logging.error(
            f"Error: The provided input path '{input_dir}' is not a valid directory."
        )
        sys.exit(1)

    try:
        process_documents(input_dir, output_dir)
    except Exception:
        # Record the traceback in the log file as well, then signal failure.
        logging.exception("Document processing failed")
        sys.exit(1)
