#!/usr/bin/env python3

import os
import json
import re
import logging
import tempfile
from typing import Dict, List, Tuple

import pandas as pd
import yaml
import nltk
from nltk.tokenize import sent_tokenize as nltk_sent_tokenize, word_tokenize
from pypdf import PdfReader
from docx import Document
import win32com.client
import pythoncom

from hazm import Normalizer, sent_tokenize as hazm_sent_tokenize
import arabic_reshaper
from bidi.algorithm import get_display

# Download required NLTK data once at startup.
# NLTK 3.9+ looks up the tokenizer and tagger models under new ids
# ('punkt_tab', 'averaged_perceptron_tagger_eng'); the legacy ids are still
# requested so older releases keep working. With quiet=True an id unknown to
# the installed release is expected to be reported as a failed download
# (False) rather than raised -- NOTE(review): confirm against the pinned
# NLTK version.
for _nltk_resource in (
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'stopwords',
):
    nltk.download(_nltk_resource, quiet=True)

class DocumentProcessor:
    """Extract text, tables, structure and key phrases from documents.

    The input format is selected by MIME type (see ``SUPPORTED_TYPES``).
    Legacy ``.doc`` files are converted through a Word COM application, so
    that path requires Windows with Word installed.
    """

    # Maps every supported MIME type to the file extension used to recognise it
    SUPPORTED_TYPES = {
        'application/pdf': '.pdf',
        'application/msword': '.doc',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
        'text/csv': '.csv',
        'application/vnd.ms-excel': '.xls',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
        'text/plain': '.txt',
        'text/markdown': '.md',
        'application/json': '.json',
        'application/yaml': '.yml',
        'application/x-yaml': '.yaml'
    }

    def __init__(self):
        """Create the text normalizer and initialise COM for this thread.

        The Word application itself is not started here: ``self.word`` stays
        ``None`` until ``extract_doc_text`` first needs it.
        """
        # hazm normalizer, applied to table cell text by clean_table_text
        self.normalizer = Normalizer()
        # Initialize COM for Word and prepare the Word application on demand
        pythoncom.CoInitialize()
        # Word COM application handle; created lazily by extract_doc_text
        self.word = None

    def __del__(self):
        """Shut down the Word application (if started) and release COM.

        Safe to run on a partially constructed instance: ``word`` is only
        assigned after ``CoInitialize()`` in ``__init__``, so when the
        attribute is missing COM was never initialised and there is nothing
        to undo.
        """
        if not hasattr(self, "word"):
            return
        try:
            if self.word:
                self.word.Quit()
        finally:
            # Always balance CoInitialize(), even if Word refused to quit,
            # and drop the handle so a second call cannot Quit() twice.
            self.word = None
            pythoncom.CoUninitialize()

    def convert_to_english_numbers(self, text: str) -> str:
        """Convert Persian/Arabic digits and commas to ASCII and tidy grouping.

        Args:
            text: Input text; ``None`` or an empty string is accepted.

        Returns:
            The text with Persian (U+06F0-U+06F9) and Arabic-Indic
            (U+0660-U+0669) digits replaced by ``0``-``9``, the Arabic comma
            (U+060C) replaced by ``,``, and every run of whitespace/commas
            that sits between a digit and a following three-digit group
            collapsed to a single ``,`` (``"1 234 567"`` -> ``"1,234,567"``).
            Returns ``""`` for falsy input.
        """
        if not text:
            return ""

        # Code-point table built from offsets so no right-to-left literals
        # are needed in the source.
        table = {0x060C: ','}  # Arabic comma
        for value in range(10):
            table[0x06F0 + value] = str(value)  # Persian digits
            table[0x0660 + value] = str(value)  # Arabic-Indic digits
        converted = text.translate(table)

        # Look-arounds leave the digits themselves unconsumed, so every
        # separator in a chain of groups is normalised, not just the first.
        return re.sub(r'(?<=\d)[\s,]+(?=\d{3})', ',', converted)

    def is_header(self, paragraph) -> bool:
        """Guess whether a paragraph is a header.

        Objects exposing ``style`` (python-docx style paragraphs) are
        headers when their style name starts with ``Heading`` or, if they
        also expose ``runs``, when every run that carries visible text is
        explicitly bold. A single bold word inside ordinary prose therefore
        no longer turns the whole paragraph into a header.

        Anything else (plain strings, objects with only ``text``) falls back
        to text heuristics: a trailing ASCII or full-width colon, or a short
        text (< 50 characters) that contains more than one newline.

        Args:
            paragraph: A paragraph-like object or a plain string.

        Returns:
            True if the paragraph looks like a header, otherwise False.
            Detection is best effort: any error is logged at debug level and
            reported as "not a header".
        """
        try:
            # For DOCX
            if hasattr(paragraph, 'style'):
                # Style or style name may be missing; treat that as "no heading style"
                style_name = getattr(paragraph.style, 'name', None) or ''
                if style_name.startswith('Heading'):
                    return True
                if hasattr(paragraph, 'runs'):
                    # Ignore empty/whitespace runs so stray formatting marks
                    # neither create nor veto a header.
                    text_runs = [
                        run for run in paragraph.runs
                        if (getattr(run, 'text', '') or '').strip()
                    ]
                    return bool(text_runs) and all(run.bold for run in text_runs)

            # For other formats, use heuristics
            text = paragraph.text if hasattr(paragraph, 'text') else str(paragraph)
            stripped = text.strip()
            # Check if text ends with common header indicators
            if stripped.endswith((':', '：')):
                return True
            # Check if text is short and spans several lines
            if len(stripped) < 50 and text.count('\n') > 1:
                return True
        except Exception:
            logging.debug(
                "Header detection failed; treating paragraph as body text",
                exc_info=True,
            )
        return False

    def extract_structure(self, paragraphs) -> List[Dict]:
        """Group paragraphs into sections keyed by the header that precedes them.

        Each paragraph is classified with ``is_header``. A header opens a new
        section; following non-header paragraphs become its content, joined
        with newlines. Paragraphs seen before the first header are not
        reported, and a section is only emitted when its header text is
        non-empty.

        Returns:
            A list of ``{'header': str, 'content': str}`` dicts in document
            order.
        """
        sections = []
        active_header = None
        pending_lines = []

        for item in paragraphs:
            heading = self.is_header(item)
            item_text = item.text if hasattr(item, 'text') else str(item)

            if not heading:
                pending_lines.append(item_text)
                continue

            # A new header closes the section collected so far
            if active_header:
                sections.append({
                    'header': active_header,
                    'content': '\n'.join(pending_lines),
                })
            active_header, pending_lines = item_text, []

        # Emit the section that was still open when the input ended
        if active_header:
            sections.append({
                'header': active_header,
                'content': '\n'.join(pending_lines),
            })

        return sections

    def clean_table_text(self, text: str) -> str:
        """Return table cell text flattened to one line and normalised.

        Line breaks become spaces, the result is passed through the
        instance's normalizer, and finally Persian digits/commas are
        converted with ``convert_to_english_numbers``. Falsy input yields
        ``""``.
        """
        if text:
            single_line = text.replace("\n", " ")
            normalized = self.normalizer.normalize(single_line)
            return self.convert_to_english_numbers(normalized)
        return ""

    def extract_sections(self, text: str) -> Dict[str, List[str]]:
        """Extract sections from text based on short header-like lines.

        A non-empty line shorter than 50 characters that contains no colon
        starts a section; every other line is appended to the most recent
        section. Bullet lines are ignored: any line containing ``•`` and any
        line that *starts with* a standalone ``o`` marker (``"o item"``).
        Lines that merely contain the letter "o" are kept. A header that
        appears more than once keeps the content already collected for it.

        Args:
            text (str): Input text to process; ``None``/empty is accepted.

        Returns:
            Dict[str, List[str]]: Dictionary with section headers as keys and
            their content as lists of strings. Content appearing before the
            first header is not included.
        """
        sections: Dict[str, List[str]] = {}
        if not text:
            return sections

        current_label = None
        for raw_line in text.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            # Bullet points: "•" anywhere, or a leading lone "o" marker
            if "•" in line or line == "o" or line.startswith(("o ", "o\t")):
                continue

            # Potential header: short line without colon
            if len(line) < 50 and ":" not in line:
                current_label = line
                # setdefault keeps earlier content when a header repeats
                sections.setdefault(current_label, [])
            # Content line: append to current section if one is open
            elif current_label is not None:
                sections[current_label].append(line)

        return sections


    def detect_language(self, text: str) -> str:
        """Classify text as Persian (``'fa'``) or English (``'en'``).

        Counts characters matching the Persian digit/letter class and
        characters matching ASCII letters; Persian wins only on a strict
        majority, so ties and empty text yield ``'en'``.
        """
        persian_hits = re.findall(r'[۰-۹آ-ی]', text)
        latin_hits = re.findall(r'[a-zA-Z]', text)
        if len(persian_hits) > len(latin_hits):
            return 'fa'
        return 'en'

    def extract_doc_text(self, file_path: str) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text, tables and structure from a legacy DOC file.

        The document is opened in a (lazily started, hidden) Word COM
        application, saved as DOCX into a private temporary directory and
        then parsed by ``extract_docx_text``. The temporary directory is
        removed and the Word document is closed even when conversion or
        parsing fails.

        Args:
            file_path: Path to the ``.doc`` file.

        Returns:
            ``(text, tables, structure)`` exactly as produced by
            ``extract_docx_text`` for the converted file.

        Raises:
            Exception: On any failure; the original error is chained as
                ``__cause__``.
        """
        try:
            if not self.word:
                self.word = win32com.client.Dispatch('Word.Application')
                self.word.Visible = False

            # A private directory avoids the name race of tempfile.mktemp and
            # guarantees cleanup on every exit path.
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_docx = os.path.join(temp_dir, 'converted.docx')

                doc = self.word.Documents.Open(os.path.abspath(file_path))
                try:
                    doc.SaveAs2(temp_docx, FileFormat=16)  # 16 = DOCX format
                finally:
                    # Never leave the source document open inside Word
                    doc.Close()

                # Process as DOCX (must happen before the directory is removed)
                return self.extract_docx_text(temp_docx)

        except Exception as e:
            raise Exception(f"Error extracting content from DOC {file_path}: {str(e)}") from e

    def extract_docx_text(self, file_path: str) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text, tables and header structure from a DOCX document.

        Args:
            file_path: Path to the ``.docx`` file.

        Returns:
            ``(text, tables, structure)`` where ``text`` is all paragraph
            text joined with newlines, ``tables`` is a list of
            ``{'table_index', 'headers', 'data'}`` dicts and ``structure`` is
            the result of ``extract_structure`` over the paragraphs.

        Raises:
            Exception: On any failure; the original error is chained as
                ``__cause__``.
        """
        try:
            doc = Document(file_path)
            paragraphs = doc.paragraphs
            text = "\n".join(paragraph.text for paragraph in paragraphs)
            structure = self.extract_structure(paragraphs)
            tables = self._extract_docx_tables(doc.tables)
            return text, tables, structure

        except Exception as e:
            raise Exception(f"Error extracting content from DOCX {file_path}: {str(e)}") from e

    def _extract_docx_tables(self, docx_tables) -> List[Dict]:
        """Convert DOCX table objects into header/data dicts.

        The first row of each table is used as the header row. Tables without
        any rows, without headers, or without at least one non-empty data row
        are skipped; ``table_index`` is the table's position in the document.
        """
        tables = []
        for idx, table in enumerate(docx_tables):
            rows = list(table.rows)
            if not rows:
                # A row-less table has no header row to read
                continue

            headers = [self.clean_table_text(cell.text.strip())
                       for cell in rows[0].cells]

            data = []
            for row in rows[1:]:
                row_data = [self.clean_table_text(cell.text.strip()) for cell in row.cells]
                if any(row_data):  # Only keep rows with at least one non-empty cell
                    data.append(row_data)

            if headers and data:
                tables.append({
                    'table_index': idx,
                    'headers': headers,
                    'data': data
                })
        return tables

    def extract_pdf_text(self, file_path: str) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text and (optionally) tables from a PDF document.

        Page text comes from ``PdfReader``; each non-empty page contributes
        its text followed by a newline. Table extraction uses Camelot when it
        can be imported and is best effort: any failure there is logged as a
        warning and the tables collected so far are returned.

        Args:
            file_path: Path to the PDF file.

        Returns:
            ``(text, tables, structure)``; ``structure`` is always an empty
            list for PDFs. String cells of each table are cleaned with
            ``clean_table_text`` and then reshaped/reordered for display.

        Raises:
            Exception: If reading the PDF itself fails; the original error is
                chained as ``__cause__``.
        """
        def prepare_cell(value):
            # Non-string cells are passed through unchanged
            if not isinstance(value, str):
                return value
            reshaped = arabic_reshaper.reshape(self.clean_table_text(value))
            return get_display(reshaped)

        try:
            reader = PdfReader(file_path)
            tables = []
            structure = []

            # Extract text from all pages, skipping pages that yield nothing
            page_texts = (page.extract_text() for page in reader.pages)
            text = "".join(f"{page_text}\n" for page_text in page_texts if page_text)

            # Extract tables using Camelot (optional feature)
            try:
                import camelot
                pdf_tables = camelot.read_pdf(file_path, pages='all', flavor='stream')
                for idx, table in enumerate(pdf_tables):
                    df = table.df
                    # Need a header row plus at least one data row
                    if df.empty or len(df) <= 1:
                        continue
                    tables.append({
                        'table_index': idx,
                        'headers': [prepare_cell(header) for header in df.iloc[0].tolist()],
                        'data': [
                            [prepare_cell(cell) for cell in row]
                            for row in df.iloc[1:].values.tolist()
                        ]
                    })
            except Exception as e:
                logging.warning("Table extraction failed: %s", e)

            return text, tables, structure

        except Exception as e:
            raise Exception(f"Error extracting content from PDF {file_path}: {str(e)}") from e

    def extract_text(self, file_path: str, mime_type: str) -> Tuple[str, List[Dict], List[Dict]]:
        """Extract text, tables and structure from a file of the given type.

        Args:
            file_path: Path to the input file.
            mime_type: MIME type selecting the extraction strategy.

        Returns:
            ``(text, tables, structure)``. PDF/DOC/DOCX are delegated to the
            dedicated extractors. CSV and Excel inputs are rendered with
            ``DataFrame.to_string`` and also returned as tables (one per
            non-empty sheet for Excel). Plain text and Markdown are returned
            as read; JSON and YAML are parsed and re-serialised. A MIME type
            that matches none of these yields ``("", [], [])``.

        Raises:
            Exception: On any failure; the original error is chained as
                ``__cause__``.
        """
        excel_types = (
            'application/vnd.ms-excel',
            'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        )
        try:
            tables = []
            text = ""
            structure = []

            if mime_type == 'application/pdf':
                text, tables, structure = self.extract_pdf_text(file_path)
            elif mime_type == 'application/msword':
                text, tables, structure = self.extract_doc_text(file_path)
            elif mime_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
                text, tables, structure = self.extract_docx_text(file_path)
            elif mime_type == 'text/csv':
                df = pd.read_csv(file_path)
                text = df.to_string()
                if not df.empty:
                    tables.append({
                        'headers': df.columns.tolist(),
                        'data': df.values.tolist()
                    })
            elif mime_type in excel_types:
                sheet_texts = []
                # Open the workbook once and close it deterministically
                with pd.ExcelFile(file_path) as xl:
                    for sheet_name in xl.sheet_names:
                        df = xl.parse(sheet_name)
                        sheet_texts.append(f"\nSheet: {sheet_name}\n" + df.to_string())
                        if not df.empty:
                            tables.append({
                                'sheet_name': sheet_name,
                                'headers': df.columns.tolist(),
                                'data': df.values.tolist()
                            })
                text = "".join(sheet_texts)
            elif mime_type in ('text/plain', 'text/markdown'):
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read()
            elif mime_type == 'application/json':
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = json.load(f)
                text = json.dumps(content, ensure_ascii=False, indent=2)
            elif mime_type in ('application/yaml', 'application/x-yaml'):
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = yaml.safe_load(f)
                text = yaml.dump(content, allow_unicode=True)
            return text, tables, structure

        except Exception as e:
            raise Exception(f"Error extracting content from {file_path}: {str(e)}") from e

    def extract_keyphrases(self, text: str, language: str) -> List[str]:
        """Return up to ten key phrases for the text.

        Digits are first converted with ``convert_to_english_numbers``. For
        ``'fa'`` the phrases are simply the first whitespace-separated tokens
        longer than two characters. For any other language the text is
        tokenised and POS-tagged with NLTK, and each maximal run of
        consecutive noun/adjective tokens (tags starting with ``NN``/``JJ``)
        becomes one space-joined phrase.

        Raises:
            Exception: If tokenising or tagging fails.
        """
        # Convert all numbers to English
        text = self.convert_to_english_numbers(text)
        try:
            if language != 'fa':
                tokens = word_tokenize(text)
                chunks = []
                run = []
                for token, pos in nltk.pos_tag(tokens):
                    if pos.startswith(('NN', 'JJ')):
                        run.append(token)
                    elif run:
                        # Any other tag closes the phrase being built
                        chunks.append(" ".join(run))
                        run = []
                if run:
                    chunks.append(" ".join(run))
                return chunks[:10]

            candidates = [token for token in text.split() if len(token) > 2]
            return candidates[:10]
        except Exception as e:
            raise Exception(f"Error extracting keyphrases: {str(e)}")

    def process_file(self, file_path: str, mime_type: str) -> Dict:
        """Process file: extract text, tables, structure, and keyphrases.

        Args:
            file_path: Path to the input file.
            mime_type: MIME type passed on to ``extract_text``.

        Returns:
            A dict with ``file_name``, a ``content`` mapping (language,
            sentences, keyphrases, keyphrase_clusters, sections, tables,
            structure) and a ``display`` mapping holding the digit-normalised
            text. ``keyphrases`` holds at most ten unique phrases in
            first-seen order: phrases from the body text first, then phrases
            from structure headers. The order is therefore stable between
            runs.

        Raises:
            Exception: Propagated from extraction or keyphrase detection.
        """
        text, tables, structure = self.extract_text(file_path, mime_type)
        normalized = self.convert_to_english_numbers(text)
        sections = self.extract_sections(normalized)
        language = self.detect_language(normalized)
        if language == 'en':
            sentences = nltk_sent_tokenize(normalized)
        else:
            sentences = hazm_sent_tokenize(normalized)

        candidates = list(self.extract_keyphrases(normalized, language))
        for section in structure:
            header_text = section.get('header', '')
            candidates.extend(self.extract_keyphrases(header_text, language))
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would make both the selection and the order hash-dependent.
        keyphrases = list(dict.fromkeys(candidates))[:10]

        # Cluster sentences under key phrases (case-insensitive substring match);
        # each sentence is lowercased once, not once per key phrase.
        lowered_sentences = [(sentence, sentence.lower()) for sentence in sentences]
        keyphrase_clusters = {}
        for kp in keyphrases:
            needle = kp.lower()
            keyphrase_clusters[kp] = [
                sentence for sentence, lowered in lowered_sentences if needle in lowered
            ]

        return {
            "file_name": os.path.basename(file_path),
            "content": {
                "language": language,
                "sentences": sentences,
                "keyphrases": keyphrases,
                "keyphrase_clusters": keyphrase_clusters,
                "sections": sections,
                "tables": tables,
                "structure": structure
            },
            "display": {
                "normalized_text": normalized
            }
        }

# --- Improved Test Routine ---

def test_document_processing() -> None:
    """Process every supported file in ``src/test/raw_file`` and save the results.

    Files are visited in sorted name order. Each file whose extension maps to
    a supported MIME type is run through ``DocumentProcessor.process_file``;
    a failure is logged and recorded as ``{"file_name", "error"}`` rather than
    aborting the run. Unsupported extensions are logged as warnings and
    sub-directories are skipped. All results are written as JSON to
    ``test_output/processing_results.json``.

    If the input directory does not exist an error is logged and the function
    returns without creating the processor or the output directory.
    """
    test_dir = os.path.join("src", "test", "raw_file")
    output_dir = "test_output"

    if not os.path.isdir(test_dir):
        logging.error(f"Test directory not found: {test_dir}")
        return

    os.makedirs(output_dir, exist_ok=True)

    processor = DocumentProcessor()
    # Reverse lookup built once: file extension -> MIME type
    extension_to_mime = {ext: mime for mime, ext in processor.SUPPORTED_TYPES.items()}
    results: List[Dict] = []

    # Sorted so the output file is stable across runs and platforms
    for filename in sorted(os.listdir(test_dir)):
        file_path = os.path.join(test_dir, filename)
        if not os.path.isfile(file_path):
            continue

        ext = os.path.splitext(filename)[1].lower()
        mime_type = extension_to_mime.get(ext)
        if not mime_type:
            logging.warning(f"Unsupported file type: {filename}")
            continue

        try:
            logging.info(f"Processing {filename}...")
            result = processor.process_file(file_path, mime_type)
            results.append(result)
            logging.info(f"Successfully processed {filename}")
        except Exception as e:
            logging.error(f"Error processing {filename}: {str(e)}")
            results.append({
                "file_name": filename,
                "error": str(e)
            })

    output_path = os.path.join(output_dir, "processing_results.json")
    try:
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        logging.info(f"Results saved to {output_path}")
    except Exception as e:
        logging.error(f"Error saving results to {output_path}: {str(e)}")

# Script entry point: configure timestamped INFO-level logging, then run the
# batch routine over the sample files.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    test_document_processing()