import requests
import base64
import json
from datetime import datetime, timedelta
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize
import mysql.connector
from dotenv import load_dotenv
import os
import pandas as pd
import logging

def get_next_working_day():
    """Return today's date, or the following Monday if today falls on a weekend, as 'YYYY-MM-DD'."""
    current_date = datetime.now()
    
    # Advance past Saturday (5) and Sunday (6) to the next weekday
    while current_date.weekday() in [5, 6]:
        current_date += timedelta(days=1)
    
    return current_date.strftime('%Y-%m-%d')
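# Usage sketch: on a weekday this returns today's date unchanged, e.g.
#   >>> get_next_working_day()
#   '2024-11-18'   # illustrative value only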

# Hard-coded list of stock ticker symbols
stock_names = [
    'AAPL', 'NVDA', 'MSFT', 'GOOGL', 'AMZN', 'META', 'TSM', 'TSLA', 'AVGO', 'LLY',
    'WMT', 'JPM', 'V', 'XOM', 'UNH', 'TCEHY', 'NVO', 'ORCL', 'MA', 'HD',
    'JNJ', 'PG', 'COST', 'ABBV', 'BAC', 'NFLX', 'KO', 'ASML', 'CRM', 'SAP',
    'CVX', 'MRK', 'TMUS', 'AMD', 'BABA', 'AZN', 'PEP', 'NVS', 'TM', 'ACN',
    'CSCO', 'WFC', 'TMO', 'IBM', 'MCD', 'ADBE', 'BX', 'PM', 'ABT', 'GE',
    'QCOM', 'NOW', 'TXN', 'MS', 'AXP', 'CAT', 'DHR', 'ISRG', 'VZ', 'RY',
    'NEE', 'PDD', 'DIS', 'INTU', 'AMGN', 'GS', 'HDB', 'PFE', 'UBER', 'CMCSA',
    'T', 'HSBC', 'UL', 'LOW', 'AMAT', 'SPGI', 'TTE', 'BLK', 'UNP', 'BKNG',
    'BHP', 'HON', 'PGR', 'SYK', 'ETN', 'SNY', 'SCHW', 'LMT', 'KKR', 'BSX',
    'TJX', 'BUD', 'ANET', 'VRTX', 'COP', 'C', 'MDT', 'PANW', 'MU', 'NKE',
    'ADP', 'CB', 'ADI', 'DE', 'PLD', 'SBUX', 'UPS', 'GILD', 'MMC', 'IBN',
    'SONY', 'BMY', 'RIO', 'UBS', 'MELI', 'AMT', 'HCA', 'REGN', 'SHOP', 'PLTR',
    'SO'
]

def encode_image(image_path):
    """Helper function to encode image files to base64."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def inpaint(prediction_width=768, prediction_height=256):
    """
    Performs image inpainting using the GetImg.ai API.
    Uses original_input.png and mask_input.png as inputs,
    and saves the result as prediction_stocks_API.png.

    Args:
        prediction_width (int): Width of the output image. Defaults to 768.
        prediction_height (int): Height of the output image. Defaults to 256.
    """
    try:
        # Encode both images
        image = encode_image("original_input.png")
        mask_image = encode_image("mask_input.png")

        # API endpoint and headers; the API key is read from the environment
        # rather than hard-coded in source (the variable name GETIMG_API_TOKEN
        # is an assumed convention, mirroring the load_dotenv() usage elsewhere
        # in this module)
        load_dotenv()
        url = "https://api.getimg.ai/v1/stable-diffusion/inpaint"
        
        headers = {
            "Authorization": f"Bearer {os.getenv('GETIMG_API_TOKEN')}",
            "Content-Type": "application/json"
        }

        # Prepare the payload
        payload = {
            "image": image,
            "mask_image": mask_image,
            "model": "stable-diffusion-v1-5-inpainting",
            "prompt": "continue this pixel pattern, use only pixels with a single color that belongs to this colormap",
            "negative_prompt": "artifacts, blurry, distorted, stretch, charts, graphs, documents",
            "steps": 100,
            "cfg_scale": 7,
            "width": prediction_width,
            "height": prediction_height,
            "output_format": "png",
            "scheduler": "euler_a"
        }

        # Make the API request (generous timeout, since 100-step inpainting is slow)
        response = requests.post(url, headers=headers, json=payload, timeout=300)
        
        if response.status_code == 200:
            # Get the image data from response
            result = response.json()
            image_bytes = base64.b64decode(result['image'])
            
            # Save the original result
            with open("prediction_stocks_API.png", "wb") as f:
                f.write(image_bytes)
            
            # Save a dated copy in predictions directory
            current_date = datetime.now().strftime("%Y-%m-%d")
            dated_filename = f"/home/a3/python_scripts/predictionAI/model1/predictions/prediction_stocks_API_{current_date}.png"
            with open(dated_filename, "wb") as f:
                f.write(image_bytes)
        else:
            raise Exception(f"API request failed with status code {response.status_code}: {response.text}")

    except Exception as e:
        raise Exception(f"Inpainting failed: {e}") from e

def inverse_custom_tanh(y, k=0.01):
    """
    Inverse of custom_tanh: recovers x from y = tanh(arctanh(0.5)/k * x).
    Only defined for |y| < 1, since arctanh diverges at ±1.
    """
    scale = np.arctanh(0.5) / k
    return np.arctanh(y) / scale
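# Round-trip sanity check (a sketch; assumes inputs stay strictly inside (-1, 1)
# so arctanh is defined):
#   >>> x = np.array([-0.02, -0.005, 0.0, 0.01])
#   >>> np.allclose(inverse_custom_tanh(custom_tanh(x, 0.01), k=0.01), x)
#   True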

def convert_from_viridis(image_path, predict_column=744, value_min=-0.32, value_max=0.51, offset_height=121):
    """
    Convert the pixel colors of one image column back to decimal values by
    nearest-neighbor matching against the viridis colormap, then inverting
    the custom tanh transform.

    Args:
        image_path (str): Path to the RGB image to decode.
        predict_column (int): Index of the column to decode.
        value_min (float): Minimum of the transformed value range used when encoding.
        value_max (float): Maximum of the transformed value range used when encoding.
        offset_height (int): Number of rows, counted from the bottom of the image, to decode.
    """
    img = Image.open(image_path)
    img_array = np.array(img)
    
    height, width, _ = img_array.shape
    column_index = predict_column
    start_row = height - offset_height
    
    target_pixels = img_array[start_row:, column_index]
    reversed_target_pixels = target_pixels[::-1]
    
    N = 256
    viridis = plt.cm.viridis
    viridis_colors = viridis(np.linspace(0, 1, N))[:, :3]
    
    def find_closest_value(pixel_color):
        pixel_color = pixel_color / 255.0
        distances = np.sqrt(np.sum((viridis_colors - pixel_color) ** 2, axis=1))
        closest_idx = np.argmin(distances)
        return closest_idx / (N - 1)
    
    normalized_values = np.array([find_closest_value(pixel) for pixel in reversed_target_pixels])
    intermediate_values = normalized_values * (value_max - value_min) + value_min
    original_values = inverse_custom_tanh(intermediate_values)
    
    return original_values
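# Example usage sketch (the default column index targets the inpainted region of
# the prediction image; the file name matches what inpaint() writes, but nothing
# in this module invokes the decode step, so treat this as an assumed workflow):
#   >>> predicted = convert_from_viridis("prediction_stocks_API.png")
#   >>> predicted.shape    # one decoded value per pixel in the predicted column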

def transform_data(df):
    """
    Apply y = tanh(10.986122886681096 * x) to every value in the dataframe.
    The constant equals arctanh(0.5) / 0.05, i.e. this is custom_tanh with k = 0.05.
    """
    return np.tanh(10.986122886681096 * df)

def custom_tanh(x, k):
    """
    Custom tanh function where ±k input maps to ±0.5 tanh output.
    
    Parameters:
    x (float or array): Input value(s)
    k (float): The input value that should map to 0.5 in the tanh output
    
    Returns:
    float or array: Transformed value(s)
    """
    scale = np.arctanh(0.5) / k
    return np.tanh(scale * x)
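# Worked example: tanh(arctanh(0.5)/k * k) = tanh(arctanh(0.5)) = 0.5 exactly,
# so an input of +k always maps to +0.5 (and -k to -0.5):
#   >>> round(custom_tanh(0.01, 0.01), 6)
#   0.5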

def convert_to_image(tanh_convert_constant=0.01, input_x_multiple=12, input_y_multiple=4):
    """
    Converts stock data to images and returns key metrics.
    
    Parameters:
    tanh_convert_constant (float): The constant k used in the tanh conversion. Default is 0.01.
    input_x_multiple (int): Output image width as a multiple of 64 pixels. Default is 12 (768 px).
    input_y_multiple (int): Output image height as a multiple of 64 pixels. Default is 4 (256 px).
    
    Returns:
    list: [global_min, global_max, height, width, output_min, output_max,
           prediction_x, prediction_y]
    """
    # Load environment variables
    load_dotenv()

    # Database connection details
    db_user = os.getenv('DB_USER')
    db_password = os.getenv('DB_PASSWORD')
    db_host = os.getenv('DB_HOST')
    db_name = 'stocks'

    try:
        # Create database connection
        cnx = mysql.connector.connect(
            user=db_user,
            password=db_password,
            host=db_host,
            database=db_name
        )
        cursor = cnx.cursor()

        # Fetch data from the database
        query = "SELECT * FROM percentage_changes"
        cursor.execute(query)

        # Fetch all rows and column names
        rows = cursor.fetchall()
        columns = [i[0] for i in cursor.description]

        # Create DataFrame
        df = pd.DataFrame(rows, columns=columns)
        df.set_index('TradeDate', inplace=True)

        # Find min and max of the raw (untransformed) data
        output_min = df.values.min()
        output_max = df.values.max()

        # Apply the transformation using the provided constant
        df_transformed = custom_tanh(df, tanh_convert_constant)

        # Find global min and max of transformed data
        global_min = df_transformed.values.min()
        global_max = df_transformed.values.max()

        # Get the dimensions of the data
        height, width = df_transformed.shape

        # Create a normalization object
        norm = Normalize(vmin=global_min, vmax=global_max)

        # Map normalized values through the viridis colormap in one vectorized
        # call (equivalent to filling an RGB array pixel-by-pixel, but faster)
        rgb_array = (plt.cm.viridis(norm(df_transformed.values))[..., :3] * 255).astype(np.uint8)

        # Create and save images
        img = Image.fromarray(rgb_array, 'RGB')
        img.save('transformed_percentage_changes.png')

        # Create figure with labels and colorbar
        dpi = 150
        fig, ax = plt.subplots(figsize=(width/dpi, height/dpi), dpi=dpi)
        im = ax.imshow(rgb_array, aspect='auto', interpolation='nearest')

        # Set axis labels and ticks
        ax.set_xlabel('Stocks')
        ax.set_ylabel('Trade Date')

        x_ticks = [0, width // 2, width - 1]
        y_ticks = [0, height // 2, height - 1]
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)

        x_labels = [df.columns[0], df.columns[width // 2], df.columns[-1]]
        y_labels = [df.index[0], df.index[height // 2], df.index[-1]]
        ax.set_xticklabels(x_labels, rotation=45, ha='right')
        ax.set_yticklabels(y_labels)

        # Add colorbar
        cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap='viridis'), ax=ax)
        cbar.set_label('Transformed Percentage Change')

        plt.tight_layout()
        plt.savefig('data_with_labels_and_colorbar.png', dpi=dpi, bbox_inches='tight', pad_inches=0)

        prediction_x = 64 * input_x_multiple
        prediction_y = 64 * input_y_multiple

        # Create original_input.png (paste the rotated data into the bottom-left
        # corner of a white canvas; use prediction_y rather than a hard-coded 256
        # so non-default input_y_multiple values position correctly)
        white_background = Image.new('RGB', (prediction_x, prediction_y), color='white')
        rotated_img = img.transpose(Image.Transpose.ROTATE_90)
        position = (0, prediction_y - rotated_img.height)
        white_background.paste(rotated_img, position)
        white_background.save('original_input.png', dpi=(300, 300))

        # Create mask_input.png
        mask_background = Image.new('RGB', (prediction_x, prediction_y), color='white')
        black_rectangle = Image.new('RGB', rotated_img.size, color='black')
        mask_background.paste(black_rectangle, position)
        mask_background.save('mask_input.png', dpi=(300, 300))

        return [global_min, global_max, height, width, output_min, output_max, prediction_x, prediction_y]

    except mysql.connector.Error as err:
        raise Exception(f"Database error: {err}")

    finally:
        if 'cursor' in locals():
            cursor.close()
        if 'cnx' in locals():
            cnx.close()

def fetch_stock_names(cursor):
    """Fetch stock names from database."""
    logging.info("Fetching stock names from database")
    try:
        cursor.execute("SELECT query AS stock_name FROM predict_names_100B ORDER BY locMarketCAPUSD DESC")
        stock_names = [row[0] for row in cursor.fetchall()]
        logging.info(f"Successfully fetched {len(stock_names)} stock names")
        return stock_names
    except mysql.connector.Error as error:
        logging.error(f"Error fetching stock names: {error}")
        return []

def fetch_stock_data(stock_name, start_date, end_date, API_TOKEN):
    """Fetch data for a stock."""
    logging.info(f"Fetching data for {stock_name}")
    url = f"https://eodhd.com/api/eod/{stock_name}?from={start_date}&to={end_date}&period=d&api_token={API_TOKEN}&fmt=json"
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        data = response.json()
        logging.info(f"Successfully fetched data for {stock_name}")
        return data
    else:
        logging.error(f"Error fetching data for {stock_name}: {response.status_code}")
        return None
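# Shape of one element of the returned JSON list (field names confirmed by the
# lookups in insert_eod_data below; values here are placeholders):
#   {'date': '2024-01-02', 'open': ..., 'high': ..., 'low': ..., 'close': ...,
#    'adjusted_close': ..., 'volume': ...}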

def insert_eod_data(connection, cursor, stock_name, data):
    """Insert EOD data into the database."""
    insert_query = """
    INSERT INTO eod_data 
    (TradeDate, StockName, Open, High, Low, Close, AdjustedClose, Volume)
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
    ON DUPLICATE KEY UPDATE
    Open = VALUES(Open),
    High = VALUES(High),
    Low = VALUES(Low),
    Close = VALUES(Close),
    AdjustedClose = VALUES(AdjustedClose),
    Volume = VALUES(Volume)
    """

    rows_to_insert = []
    for row in data:
        rows_to_insert.append((
            row['date'],
            stock_name,
            round(float(row['open']), 2),
            round(float(row['high']), 2),
            round(float(row['low']), 2),
            round(float(row['close']), 2),
            round(float(row['adjusted_close']), 6),
            int(row['volume'])
        ))
    
    try:
        cursor.executemany(insert_query, rows_to_insert)
        connection.commit()
        logging.info(f"Inserted/Updated {len(rows_to_insert)} rows for {stock_name} in eod_data table")
    except mysql.connector.Error as error:
        logging.error(f"Error inserting data for {stock_name} into eod_data table: {error}")
        connection.rollback()

def calculate_percentage_change(data):
    """Calculate percentage change using adjusted_close."""
    logging.info("Calculating percentage change using adjusted_close")
    df = pd.DataFrame(data)
    df['date'] = pd.to_datetime(df['date'])
    df = df.sort_values('date')
    df['perc_change'] = df['adjusted_close'].pct_change()
    return df
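# Minimal input sketch (the dict keys mirror the EODHD JSON consumed by
# fetch_stock_data; the sample values are illustrative only):
#   >>> sample = [{'date': '2024-01-02', 'adjusted_close': 100.0},
#   ...           {'date': '2024-01-03', 'adjusted_close': 101.0}]
#   >>> round(calculate_percentage_change(sample)['perc_change'].iloc[-1], 6)
#   0.01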

def fetch_and_process_stock_data(input_start_days, input_end_days):
    """
    Main function to fetch and process stock data.
    
    Args:
        input_start_days (int): Number of days to look back from end date
        input_end_days (int): Number of days to offset from current date for end date
        
    Returns:
        list: [start_date, end_date]
    """
    # Set up logging
    logging.basicConfig(level=logging.INFO, 
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler("script_log.txt"),
                            logging.StreamHandler()
                        ])

    logging.info("Script started")

    # Load environment variables
    load_dotenv()
    API_TOKEN = os.getenv('MY_API_TOKEN')
    logging.info("Environment variables loaded")

    # Database connection details
    STOCK_DB_HOST = os.getenv('STOCK_DB_HOST')
    STOCK_DB_USER = os.getenv('STOCK_DB_USER')
    STOCK_DB_PASSWORD = os.getenv('STOCK_DB_PASSWORD')
    STOCK_DB_NAME = os.getenv('STOCK_DB_NAME')
    logging.info("Database connection details retrieved")

    # Calculate date range
    end_date = datetime.now().date() - timedelta(days=input_end_days)
    while end_date.weekday() > 4:  # Adjust to last weekday
        end_date -= timedelta(days=1)
    start_date = end_date - timedelta(days=input_start_days)
    
    # Store these for return values
    output_start_date = start_date
    output_end_date = end_date
    
    logging.info(f"Date range calculated: {start_date} to {end_date}")

    try:
        connection = mysql.connector.connect(
            host=STOCK_DB_HOST,
            user=STOCK_DB_USER,
            password=STOCK_DB_PASSWORD,
            database=STOCK_DB_NAME
        )
        cursor = connection.cursor()
        logging.info("Connected to STOCKNAME database")

        # Ensure eod_data table exists
        create_eod_table_query = """
        CREATE TABLE IF NOT EXISTS eod_data (
            TradeDate DATE,
            StockName VARCHAR(255),
            Open DECIMAL(10,2),
            High DECIMAL(10,2),
            Low DECIMAL(10,2),
            Close DECIMAL(10,2),
            AdjustedClose DECIMAL(20,6),
            Volume BIGINT,
            PRIMARY KEY (TradeDate, StockName)
        ) ENGINE=InnoDB
        """
        cursor.execute(create_eod_table_query)
        logging.info("Ensured eod_data table exists")

        # Fetch stock names
        stock_names = fetch_stock_names(cursor)
        logging.info(f"Fetched {len(stock_names)} stock names")

        # Fetch data and calculate percentage change for each stock
        stock_data = {}
        for i, stock in enumerate(stock_names, 1):
            logging.info(f"Processing stock {i}/{len(stock_names)}: {stock}")
            data = fetch_stock_data(stock, start_date, end_date, API_TOKEN)
            if data:
                # Insert into eod_data table
                insert_eod_data(connection, cursor, stock, data)
                # Calculate percentage change
                stock_data[stock] = calculate_percentage_change(data)
            logging.info(f"Completed processing for {stock}")

        logging.info("Constructing complete DataFrame of percentage changes")

        # Create a date range for all possible trading days
        all_dates = pd.date_range(start=start_date, end=end_date)

        # Initialize an empty DataFrame with dates as index
        complete_df = pd.DataFrame(index=all_dates)

        # Fill the DataFrame with percentage changes for each stock
        for stock in stock_names:
            if stock in stock_data:
                df = stock_data[stock]
                complete_df[stock] = df.set_index('date')['perc_change']

        # Sort the DataFrame by date
        complete_df.sort_index(inplace=True)

        # Remove rows where all values are NaN (non-trading days)
        complete_df.dropna(how='all', inplace=True)

        # Fill NaN values with 0 (for days when a particular stock didn't trade)
        complete_df.fillna(0, inplace=True)

        # Drop the percentage_changes table if it exists
        cursor.execute("DROP TABLE IF EXISTS percentage_changes")
        logging.info("Dropped existing percentage_changes table")

        # Create the new percentage_changes table
        create_table_query = """
        CREATE TABLE percentage_changes (
            TradeDate DATE,
            {}
        ) ENGINE=InnoDB
        """.format(", ".join([f"`{col}` FLOAT" for col in complete_df.columns]))

        cursor.execute(create_table_query)
        logging.info("Created new percentage_changes table")

        # Prepare the data for insertion
        columns = ['TradeDate'] + list(complete_df.columns)
        placeholders = ', '.join(['%s'] * len(columns))
        
        insert_query = f"""
        INSERT INTO percentage_changes ({', '.join(columns)})
        VALUES ({placeholders})
        """

        # Convert DataFrame to list of tuples for insertion
        data_to_insert = [tuple([index.date()] + list(row)) for index, row in complete_df.iterrows()]

        # Insert the data
        cursor.executemany(insert_query, data_to_insert)
        connection.commit()
        logging.info(f"Inserted {len(data_to_insert)} rows into percentage_changes table")

        # Save the complete DataFrame to a CSV file
        csv_filename = 'stock_percentage_changes.csv'
        complete_df.to_csv(csv_filename)
        logging.info(f"Complete DataFrame saved to {csv_filename}")

    except mysql.connector.Error as error:
        logging.error(f"Error with database operations: {error}")

    finally:
        # Guard against NameError when the initial connect() itself fails
        if 'connection' in locals() and connection.is_connected():
            cursor.close()
            connection.close()
            logging.info(f"{STOCK_DB_NAME} database connection closed")

    logging.info("Script completed")
    
    return [output_start_date, output_end_date]
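
# A minimal end-to-end driver sketch. The call order is inferred from the file
# artifacts each step produces and consumes (convert_to_image writes
# original_input.png and mask_input.png, inpaint reads them and writes
# prediction_stocks_API.png, convert_from_viridis decodes the predicted column);
# the numeric arguments are illustrative assumptions, not values confirmed by
# this module.
if __name__ == "__main__":
    start_date, end_date = fetch_and_process_stock_data(input_start_days=365, input_end_days=0)
    metrics = convert_to_image(tanh_convert_constant=0.01)
    global_min, global_max, height, width, *_, pred_x, pred_y = metrics
    inpaint(prediction_width=pred_x, prediction_height=pred_y)
    predicted_values = convert_from_viridis("prediction_stocks_API.png",
                                            value_min=global_min, value_max=global_max)
    print(predicted_values)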
