import os
import mysql.connector
import requests
import pandas as pd
import numpy as np
from PIL import Image
from dotenv import load_dotenv
from datetime import datetime, timedelta
import logging

def fetch_and_process_stock_data(input_start_days: int, input_end_days: int) -> list:
    """Fetch EOD stock prices, upsert them into MySQL, and build a percentage-change matrix.

    Pipeline (all performed within this call):
      1. Read stock names from the ``predict_names_100B`` table.
      2. Download daily EOD bars per stock from eodhd.com for the computed date range.
      3. Upsert the bars into the ``eod_data`` table.
      4. Build a date-by-stock DataFrame of adjusted-close percentage changes,
         rewrite the ``percentage_changes`` table with it, and save it to CSV.

    Args:
        input_start_days: Length of the lookback window in calendar days;
            start_date = end_date - input_start_days.
        input_end_days: Offset back from today used to pick end_date, which is
            then rolled back to the most recent weekday.

    Returns:
        A two-element list ``[start_date, end_date]`` of ``datetime.date`` objects.
    """
    # Set up logging to both a file and the console.
    # NOTE(review): basicConfig runs on every call; per the logging docs it is
    # a no-op once the root logger has handlers, so only the first call wins.
    logging.basicConfig(level=logging.INFO, 
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler("script_log.txt"),
                            logging.StreamHandler()
                        ])

    logging.info("Script started")

    # Load environment variables from a local .env file (python-dotenv).
    load_dotenv()
    API_TOKEN = os.getenv('MY_API_TOKEN')  # eodhd.com API token
    logging.info("Environment variables loaded")

    # Database connection details (any of these may be None if unset in the env).
    STOCK_DB_HOST = os.getenv('STOCK_DB_HOST')
    STOCK_DB_USER = os.getenv('STOCK_DB_USER')
    STOCK_DB_PASSWORD = os.getenv('STOCK_DB_PASSWORD')
    STOCK_DB_NAME = os.getenv('STOCK_DB_NAME')
    logging.info("Database connection details retrieved")

    # Calculate date range: end_date is input_end_days ago, rolled back to a weekday.
    end_date = datetime.now().date() - timedelta(days=input_end_days)
    while end_date.weekday() > 4:  # Mon=0 .. Fri=4; skip Sat/Sun back to Friday
        end_date -= timedelta(days=1)
    start_date = end_date - timedelta(days=input_start_days)
    
    # Store these for return values (the nested helpers below close over
    # start_date / end_date / API_TOKEN by name).
    output_start_date = start_date
    output_end_date = end_date
    
    logging.info(f"Date range calculated: {start_date} to {end_date}")
    # Helper: pull the tracked stock tickers out of the database.
    def fetch_stock_names(cursor):
        """Return stock names from predict_names_100B, largest market cap first.

        On a database error, logs the problem and returns an empty list so the
        caller can proceed with nothing to process.
        """
        logging.info("Fetching stock names from database")
        try:
            cursor.execute("SELECT query AS stock_name FROM predict_names_100B ORDER BY locMarketCAPUSD DESC")
            rows = cursor.fetchall()
        except mysql.connector.Error as error:
            logging.error(f"Error fetching stock names: {error}")
            return []
        names = [record[0] for record in rows]
        logging.info(f"Successfully fetched {len(names)} stock names")
        return names

    # Helper: fetch one stock's daily bars from the eodhd.com EOD endpoint.
    def fetch_stock_data(stock_name):
        """Return the JSON list of daily EOD bars for stock_name, or None on failure.

        Failures (non-200 status, timeout, connection error) are logged and
        reported as None so the caller can skip the stock instead of aborting
        the whole run.
        """
        logging.info(f"Fetching data for {stock_name}")
        url = f"https://eodhd.com/api/eod/{stock_name}?from={start_date}&to={end_date}&period=d&api_token={API_TOKEN}&fmt=json"
        try:
            # Timeout keeps one unresponsive endpoint from hanging the entire
            # script; previously there was no timeout and a network exception
            # would escape past the outer mysql-only except clause.
            response = requests.get(url, timeout=30)
        except requests.RequestException as error:
            logging.error(f"Error fetching data for {stock_name}: {error}")
            return None
        if response.status_code == 200:
            data = response.json()
            logging.info(f"Successfully fetched data for {stock_name}")
            return data
        logging.error(f"Error fetching data for {stock_name}: {response.status_code}")
        return None

    # Helper: bulk upsert one stock's daily bars into the eod_data table.
    def insert_eod_data(connection, cursor, stock_name, data):
        """Insert or refresh daily OHLCV rows for stock_name (upsert on duplicate key).

        Commits on success; logs and rolls back on a database error.
        """
        insert_query = """
        INSERT INTO eod_data 
        (TradeDate, StockName, Open, High, Low, Close, AdjustedClose, Volume)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
        Open = VALUES(Open),
        High = VALUES(High),
        Low = VALUES(Low),
        Close = VALUES(Close),
        AdjustedClose = VALUES(AdjustedClose),
        Volume = VALUES(Volume)
        """
    
        # Prices are rounded to 2 decimals, adjusted close to 6 (matches the
        # DECIMAL(10,2) / DECIMAL(20,6) columns of the eod_data table).
        rows_to_insert = [
            (
                bar['date'],
                stock_name,
                round(float(bar['open']), 2),
                round(float(bar['high']), 2),
                round(float(bar['low']), 2),
                round(float(bar['close']), 2),
                round(float(bar['adjusted_close']), 6),
                int(bar['volume']),
            )
            for bar in data
        ]
        
        try:
            cursor.executemany(insert_query, rows_to_insert)
            connection.commit()
            logging.info(f"Inserted/Updated {len(rows_to_insert)} rows for {stock_name} in eod_data table")
        except mysql.connector.Error as error:
            logging.error(f"Error inserting data for {stock_name} into eod_data table: {error}")
            connection.rollback()

    # Helper: turn raw EOD JSON rows into a date-sorted DataFrame with % change.
    def calculate_percentage_change(data):
        """Return a DataFrame sorted by date with a perc_change column.

        perc_change is the day-over-day fractional change of adjusted_close;
        the earliest row has no prior day and is therefore NaN.
        """
        logging.info("Calculating percentage change using adjusted_close")
        return (
            pd.DataFrame(data)
            .assign(date=lambda frame: pd.to_datetime(frame['date']))
            .sort_values('date')
            .assign(perc_change=lambda frame: frame['adjusted_close'].pct_change())
        )

    try:
        connection = mysql.connector.connect(
            host=STOCK_DB_HOST,
            user=STOCK_DB_USER,
            password=STOCK_DB_PASSWORD,
            database=STOCK_DB_NAME
        )
        cursor = connection.cursor()
        logging.info("Connected to STOCKNAME database")

        # Ensure eod_data table exists
        create_eod_table_query = """
        CREATE TABLE IF NOT EXISTS eod_data (
            TradeDate DATE,
            StockName VARCHAR(255),
            Open DECIMAL(10,2),
            High DECIMAL(10,2),
            Low DECIMAL(10,2),
            Close DECIMAL(10,2),
            AdjustedClose DECIMAL(20,6),
            Volume BIGINT,
            PRIMARY KEY (TradeDate, StockName)
        ) ENGINE=InnoDB
        """
        cursor.execute(create_eod_table_query)
        logging.info("Ensured eod_data table exists")

        # Fetch stock names
        stock_names = fetch_stock_names(cursor)
        logging.info(f"Fetched {len(stock_names)} stock names")

        # Fetch data and calculate percentage change for each stock
        stock_data = {}
        for i, stock in enumerate(stock_names, 1):
            logging.info(f"Processing stock {i}/{len(stock_names)}: {stock}")
            data = fetch_stock_data(stock)
            if data:
                # Insert into eod_data table
                insert_eod_data(connection, cursor, stock, data)
                # Calculate percentage change
                stock_data[stock] = calculate_percentage_change(data)
            logging.info(f"Completed processing for {stock}")

        logging.info("Constructing complete DataFrame of percentage changes")

        # Create a date range for all possible trading days
        all_dates = pd.date_range(start=start_date, end=end_date)

        # Initialize an empty DataFrame with dates as index
        complete_df = pd.DataFrame(index=all_dates)

        # Fill the DataFrame with percentage changes for each stock
        for stock in stock_names:
            if stock in stock_data:
                df = stock_data[stock]
                complete_df[stock] = df.set_index('date')['perc_change']

        # Sort the DataFrame by date
        complete_df.sort_index(inplace=True)

        # Remove rows where all values are NaN (non-trading days)
        complete_df.dropna(how='all', inplace=True)

        # Fill NaN values with 0 (for days when a particular stock didn't trade)
        complete_df.fillna(0, inplace=True)

        # Drop the percentage_changes table if it exists
        cursor.execute("DROP TABLE IF EXISTS percentage_changes")
        logging.info("Dropped existing percentage_changes table")

        # Create the new percentage_changes table
        create_table_query = """
        CREATE TABLE percentage_changes (
            TradeDate DATE,
            {}
        ) ENGINE=InnoDB
        """.format(", ".join([f"`{col}` FLOAT" for col in complete_df.columns]))

        cursor.execute(create_table_query)
        logging.info("Created new percentage_changes table")

        # Prepare the data for insertion
        columns = ['TradeDate'] + list(complete_df.columns)
        placeholders = ', '.join(['%s'] * len(columns))
        
        insert_query = f"""
        INSERT INTO percentage_changes ({', '.join(columns)})
        VALUES ({placeholders})
        """

        # Convert DataFrame to list of tuples for insertion
        data_to_insert = [tuple([index.date()] + list(row)) for index, row in complete_df.iterrows()]

        # Insert the data
        cursor.executemany(insert_query, data_to_insert)
        connection.commit()
        logging.info(f"Inserted {len(data_to_insert)} rows into percentage_changes table")

    except mysql.connector.Error as error:
        logging.error(f"Error with database operations: {error}")

    finally:
        if connection.is_connected():
            cursor.close()
            connection.close()
            logging.info("STOCKNAME database connection closed")

    # Save the complete DataFrame to a CSV file
    csv_filename = 'stock_percentage_changes.csv'
    complete_df.to_csv(csv_filename)
    logging.info(f"Complete DataFrame saved to {csv_filename}")

    logging.info("Script completed")
    
    return [output_start_date, output_end_date]

if __name__ == "__main__":
    # Example usage: a roughly three-year window ending one day back from today.
    window_days = 360 * 3
    offset_days = 1
    start_date, end_date = fetch_and_process_stock_data(window_days, offset_days)
    print(f"Start Date: {start_date}")
    print(f"End Date: {end_date}")

#  INPUT:
#   input_start_days : 360*3 (lookback window length, in calendar days)
#   input_end_days   : 1     (offset back from today for the window's end)
#
#  OUTPUT:
#   start_date (date the window begins)
#   end_date   (date the window ends, rolled back to a weekday)