import pandas as pd
import mysql.connector
import os
import re
from datetime import datetime
import logging
from dotenv import load_dotenv

# Pull settings from a local .env file into the process environment so the
# importer can be configured without exporting variables manually.
load_dotenv()

# Module-wide logging: timestamped, INFO-level messages to the root handler.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

logger = logging.getLogger(__name__)

class TickDataImporter:
    """Import MetaTrader tick-data CSV exports into per-instrument MySQL tables.

    Each CSV named ``<UNDERLIER>_<YYYYMMDDHHMM>_<YYYYMMDDHHMM>.csv`` is loaded
    into a table ``<UNDERLIER>_TICKDATA_<start>_<end>`` created on demand.
    """

    def __init__(self):
        """Read MySQL connection settings from the environment.

        Host and database have defaults ('localhost' / 'Metatrader');
        MYSQL_USER and MYSQL_PASSWORD must be supplied via the environment
        (typically loaded from .env).
        """
        self.db_config = {
            'host': os.getenv('MYSQL_HOST', 'localhost'),
            'user': os.getenv('MYSQL_USER'),
            'password': os.getenv('MYSQL_PASSWORD'),
            'database': os.getenv('MYSQL_DATABASE', 'Metatrader')
        }
        self.conn = None
        self.cursor = None

    def connect(self):
        """Open the MySQL connection and cursor.

        Returns:
            bool: True on success, False if the connection attempt failed
            (the error is logged, not raised).
        """
        try:
            self.conn = mysql.connector.connect(**self.db_config)
            self.cursor = self.conn.cursor()
            logger.info("Successfully connected to MySQL database")
            return True
        except mysql.connector.Error as err:
            logger.error(f"Failed to connect to MySQL: {err}")
            return False

    def parse_filename(self, filename):
        """Derive table name and date range from a tick-data filename.

        Expected stem format: ``<UNDERLIER>_<YYYYMMDDHHMM>_<YYYYMMDDHHMM>``;
        the underlier itself may contain underscores, so the two trailing
        parts are taken as the dates.

        Args:
            filename: Base name of the CSV file (extension optional).

        Returns:
            dict: ``table_name`` (str), ``start_date`` and ``end_date``
            (datetime objects).

        Raises:
            ValueError: If the name has fewer than three '_'-separated parts
                or either date field does not parse.
        """
        # Only the stem carries information; drop the extension.
        stem = os.path.splitext(filename)[0]
        parts = stem.split('_')

        if len(parts) < 3:
            # BUG FIX: the original message contained the literal text
            # "(unknown)" instead of interpolating the offending filename.
            raise ValueError(
                f"Invalid filename format: {stem}. "
                f"Expected at least 3 parts separated by underscore."
            )

        # The last two components are the start/end timestamps; everything
        # before them is the instrument name.
        start_str = parts[-2]
        end_str = parts[-1]
        underlier = '_'.join(parts[:-2])

        try:
            start_datetime = datetime.strptime(start_str, '%Y%m%d%H%M')
            end_datetime = datetime.strptime(end_str, '%Y%m%d%H%M')
        except ValueError as e:
            raise ValueError(f"Invalid date format in filename: {e}")

        # Reuse the original date strings verbatim in the table name.
        table_name = f"{underlier}_TICKDATA_{start_str}_{end_str}"

        return {
            'table_name': table_name,
            'start_date': start_datetime,
            'end_date': end_datetime
        }

    def create_table(self, table_name):
        """Create the tick-data table if it doesn't already exist.

        The (date, time) pair is the primary key, so re-importing the same
        file is idempotent together with INSERT IGNORE.

        Returns:
            bool: True if the table exists after the call, False on error.
        """
        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS `{table_name}` (
            `date` DATE NOT NULL,
            `time` TIME(3) NOT NULL,
            `bid` DECIMAL(15,6) NULL,
            `ask` DECIMAL(15,6) NULL,
            `last` DECIMAL(15,6) NULL,
            `volume` INT NULL,
            `flags` INT NULL,
            PRIMARY KEY (`date`, `time`)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """

        try:
            self.cursor.execute(create_table_sql)
            self.conn.commit()
            logger.info(f"Table {table_name} created or already exists")
            return True
        except mysql.connector.Error as err:
            logger.error(f"Failed to create table: {err}")
            return False

    def process_csv(self, csv_path, chunk_size=1000):
        """Stream a tab-separated tick CSV into its MySQL table in chunks.

        Args:
            csv_path: Path to the MetaTrader export
                (columns: date, time, bid, ask, last, volume, flags).
            chunk_size: Rows per pandas chunk / INSERT batch.

        Returns:
            bool: True if at least one row was imported, False otherwise.
        """
        try:
            # Derive target table from the filename and make sure it exists.
            file_info = self.parse_filename(os.path.basename(csv_path))
            table_name = file_info['table_name']

            if not self.create_table(table_name):
                return False

            # Stream the file so arbitrarily large exports fit in memory.
            chunks = pd.read_csv(
                csv_path,
                sep='\t',
                names=['date', 'time', 'bid', 'ask', 'last', 'volume', 'flags'],
                skiprows=1,  # Skip the header row with <DATE>, <TIME>, etc.
                chunksize=chunk_size,
                na_values=['']  # Treat empty strings as NULL
            )

            # BUG FIX (insert_sql below): quote the table name with backticks,
            # consistent with create_table, so names that need quoting work.
            insert_sql = f"""
            INSERT IGNORE INTO `{table_name}`
            (date, time, bid, ask, last, volume, flags)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            """

            total_rows = 0
            for chunk in chunks:
                # MetaTrader uses 'YYYY.MM.DD'; convert to a real date.
                chunk['date'] = pd.to_datetime(chunk['date'], format='%Y.%m.%d').dt.date

                # BUG FIX: values from iterrows() are numpy scalars
                # (float64/int64), which mysql-connector cannot convert.
                # Cast to native Python float/int before binding.
                values = []
                for _, row in chunk.iterrows():
                    value = (
                        row['date'],
                        row['time'],
                        float(row['bid']) if pd.notna(row['bid']) else None,
                        float(row['ask']) if pd.notna(row['ask']) else None,
                        float(row['last']) if pd.notna(row['last']) else None,
                        int(row['volume']) if pd.notna(row['volume']) else None,
                        int(row['flags']) if pd.notna(row['flags']) else None
                    )
                    values.append(value)

                try:
                    self.cursor.executemany(insert_sql, values)
                    self.conn.commit()
                    total_rows += len(values)
                    logger.info(f"Inserted {len(values)} rows. Total: {total_rows}")
                except mysql.connector.Error as err:
                    # Best-effort: skip the bad chunk and keep importing.
                    logger.error(f"Failed to insert chunk: {err}")
                    self.conn.rollback()
                    continue

            if total_rows > 0:
                logger.info(f"Completed importing {total_rows} rows to {table_name}")
                self.analyze_table(table_name)
                return True
            return False

        except Exception as e:
            logger.error(f"Error processing CSV: {e}")
            return False

    def analyze_table(self, table_name):
        """Run ANALYZE TABLE so the optimizer statistics reflect the import.

        Returns:
            bool: True on success, False on error.
        """
        try:
            # Backtick-quoted for consistency with create_table.
            self.cursor.execute(f"ANALYZE TABLE `{table_name}`")
            # BUG FIX: ANALYZE TABLE returns a result set; consume it so the
            # connection is not left with unread results (which would make
            # the next execute() raise "Unread result found").
            self.cursor.fetchall()
            logger.info(f"Table {table_name} analyzed successfully")
            return True
        except mysql.connector.Error as err:
            logger.error(f"Failed to analyze table: {err}")
            return False

    def close(self):
        """Close the cursor and connection if they were opened."""
        if self.cursor:
            self.cursor.close()
        if self.conn:
            self.conn.close()
            logger.info("Database connection closed")

def main():
    """Entry point: import one hard-coded tick-data CSV into MySQL."""
    importer = TickDataImporter()

    # Bail out early if the database is unreachable.
    if not importer.connect():
        return

    csv_file = "/home/a3/ETHUSD_202501040045_202504101350.csv"  # Replace with your CSV file
    try:
        imported = importer.process_csv(csv_file)
        if imported:
            logger.info("CSV import completed successfully")
        else:
            logger.error("Failed to import CSV")
    finally:
        # Always release the connection, even if processing raised.
        importer.close()


if __name__ == "__main__":
    main()
