import os
import pandas as pd
import mysql.connector
import numpy as np
import matplotlib.pyplot as plt
from dotenv import load_dotenv
from matplotlib.colors import Normalize
from PIL import Image

def transform_data(df):
    """
    Squash every value of ``df`` with a fixed-gain hyperbolic tangent.

    Computes ``y = tanh(10.986122886681096 * x)`` element-wise. The gain is
    ``arctanh(0.5) / 0.05`` (= 10 * ln 3), so inputs of +/-0.05 map to +/-0.5,
    i.e. the same curve as ``custom_tanh(x, 0.05)``.

    Parameters:
    df (DataFrame, array or float): Value(s) to transform.

    Returns:
    Same kind as ``df``: the tanh-transformed value(s), in (-1, 1).
    """
    gain = 10.986122886681096
    scaled = df * gain
    return np.tanh(scaled)

def custom_tanh(x, k):
    """
    Scaled tanh whose output is exactly +/-0.5 when the input is +/-k.

    The input is multiplied by ``arctanh(0.5) / k`` before applying tanh, so
    ``k`` controls how quickly values saturate towards +/-1.

    Parameters:
    x (float, array or DataFrame): Input value(s)
    k (float): Input magnitude that should produce a tanh output of 0.5

    Returns:
    Same kind as ``x``: transformed value(s) in (-1, 1)
    """
    # arctanh(0.5) is the tanh argument that yields 0.5; dividing by k moves
    # that point onto x == k.
    gain = np.arctanh(0.5) / k
    scaled_input = gain * x
    return np.tanh(scaled_input)

def _fetch_percentage_changes():
    """
    Load the ``percentage_changes`` table into a DataFrame indexed by ``TradeDate``.

    Connection credentials are read from the DB_USER / DB_PASSWORD / DB_HOST
    environment variables (``load_dotenv`` may populate them from a .env file);
    the database name is fixed to ``stocks``. The connection is closed before
    returning, so no DB resources are held during image rendering.

    Raises:
    Exception: "Database error: ..." wrapping any ``mysql.connector.Error``.
    """
    load_dotenv()

    cnx = None
    cursor = None
    try:
        cnx = mysql.connector.connect(
            user=os.getenv('DB_USER'),
            password=os.getenv('DB_PASSWORD'),
            host=os.getenv('DB_HOST'),
            database='stocks',
        )
        cursor = cnx.cursor()
        cursor.execute("SELECT * FROM percentage_changes")
        rows = cursor.fetchall()
        columns = [desc[0] for desc in cursor.description]
    except mysql.connector.Error as err:
        # Keep the original exception type for callers, but preserve the cause.
        raise Exception(f"Database error: {err}") from err
    finally:
        if cursor is not None:
            cursor.close()
        if cnx is not None:
            cnx.close()

    df = pd.DataFrame(rows, columns=columns)
    df.set_index('TradeDate', inplace=True)
    # NOTE(review): if the columns are DECIMAL the connector yields
    # decimal.Decimal objects, which numpy's tanh cannot handle; coercing to
    # float is a no-op for FLOAT/DOUBLE columns and turns NULLs into NaN.
    return df.astype(float)


def _values_to_rgb(values, norm):
    """Map a 2-D float array through viridis under ``norm``; return (h, w, 3) uint8 RGB."""
    # One vectorized colormap call instead of a per-pixel Python loop.
    rgba = plt.cm.viridis(norm(values))
    return (rgba[..., :3] * 255).astype(np.uint8)


def _save_labeled_figure(rgb_array, df, norm, path='data_with_labels_and_colorbar.png', dpi=150):
    """Save ``rgb_array`` as a figure with first/middle/last axis labels and a colorbar."""
    height, width = rgb_array.shape[:2]
    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        ax.imshow(rgb_array, aspect='auto', interpolation='nearest')

        ax.set_xlabel('Stocks')
        ax.set_ylabel('Trade Date')

        ax.set_xticks([0, width // 2, width - 1])
        ax.set_yticks([0, height // 2, height - 1])
        ax.set_xticklabels([df.columns[0], df.columns[width // 2], df.columns[-1]],
                           rotation=45, ha='right')
        ax.set_yticklabels([df.index[0], df.index[height // 2], df.index[-1]])

        cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap='viridis'), ax=ax)
        cbar.set_label('Transformed Percentage Change')

        fig.tight_layout()
        fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every open figure alive; release it explicitly.
        plt.close(fig)


def _save_prediction_inputs(img, prediction_x, prediction_y):
    """
    Write original_input.png and mask_input.png on a prediction_x x prediction_y canvas.

    ``img`` is rotated 90 degrees and pasted at the bottom-left corner of a white
    canvas; the mask is the same canvas with a black rectangle of the rotated
    image's size at that position.
    """
    rotated_img = img.transpose(Image.ROTATE_90)
    # Bottom-align against the actual canvas height (was hard-coded to 256,
    # which only matched the default input_y_multiple of 4).
    position = (0, prediction_y - rotated_img.height)

    white_background = Image.new('RGB', (prediction_x, prediction_y), color='white')
    white_background.paste(rotated_img, position)
    white_background.save('original_input.png', dpi=(300, 300))

    mask_background = Image.new('RGB', (prediction_x, prediction_y), color='white')
    black_rectangle = Image.new('RGB', rotated_img.size, color='black')
    mask_background.paste(black_rectangle, position)
    mask_background.save('mask_input.png', dpi=(300, 300))


def convert_to_image(tanh_convert_constant=0.01, input_x_multiple=12, input_y_multiple=4):
    """
    Convert the stored stock percentage changes to images and return key metrics.

    Reads ``stocks.percentage_changes`` (rows = trade dates, columns = stocks),
    squashes values with ``custom_tanh``, colors them with viridis and writes:
    transformed_percentage_changes.png, data_with_labels_and_colorbar.png,
    original_input.png and mask_input.png in the current directory.

    Parameters:
    tanh_convert_constant (float): Input magnitude that ``custom_tanh`` maps
        to +/-0.5. Default is 0.01.
    input_x_multiple (int): Prediction canvas width in units of 64 px. Default 12.
    input_y_multiple (int): Prediction canvas height in units of 64 px. Default 4.

    Returns:
    list: [global_min, global_max, height, width, output_min, output_max,
           prediction_x, prediction_y] where global_* is the range of the
           transformed data, output_* the range of the raw data, height/width
           the data shape and prediction_x/y the canvas size in pixels.

    Raises:
    Exception: "Database error: ..." if the MySQL query fails.
    ValueError: if the table returns no data.
    """
    df = _fetch_percentage_changes()
    if df.size == 0:
        raise ValueError("percentage_changes returned no data")

    # Range of the raw (untransformed) data; NaN (NULL) cells are ignored.
    output_min = np.nanmin(df.values)
    output_max = np.nanmax(df.values)

    df_transformed = custom_tanh(df, tanh_convert_constant)

    # Range of the transformed data, used for color normalization.
    global_min = np.nanmin(df_transformed.values)
    global_max = np.nanmax(df_transformed.values)

    height, width = df_transformed.shape
    norm = Normalize(vmin=global_min, vmax=global_max)

    rgb_array = _values_to_rgb(df_transformed.to_numpy(), norm)

    img = Image.fromarray(rgb_array, 'RGB')
    img.save('transformed_percentage_changes.png')

    _save_labeled_figure(rgb_array, df, norm)

    prediction_x = 64 * input_x_multiple
    prediction_y = 64 * input_y_multiple
    _save_prediction_inputs(img, prediction_x, prediction_y)

    return [global_min, global_max, height, width, output_min, output_max, prediction_x, prediction_y]

if __name__ == "__main__":
    # Example usage with the default tanh_convert_constant (0.01).
    (global_min, global_max, height, width,
     output_min, output_max, prediction_x, prediction_y) = convert_to_image()
    print(f"Global Min: {global_min}")
    print(f"Global Max: {global_max}")
    print(f"Value Min: {output_min}")
    print(f"Value Max: {output_max}")
    print(f"Height: {height}")
    print(f"Width: {width}")
    print(f"Prediction Height: {prediction_y}")
    print(f"Prediction Width: {prediction_x}")

    # Example usage with a custom tanh_convert_constant. convert_to_image
    # always returns eight values, so all of them must be unpacked:
    # (global_min, global_max, height, width,
    #  output_min, output_max, prediction_x, prediction_y) = convert_to_image(tanh_convert_constant=0.02)

    # OUTPUT (returned, in order):
    #   global_min       - minimum of the tanh-transformed data
    #   global_max       - maximum of the tanh-transformed data
    #   height           - number of trade dates (rows)
    #   width            - number of stocks (columns)
    #   output_min       - minimum of the raw percentage changes
    #   output_max       - maximum of the raw percentage changes
    #   prediction_x     - prediction canvas width in pixels
    #   prediction_y     - prediction canvas height in pixels
    #   files written to the current directory:
    #       original_input.png
    #       mask_input.png
    #       data_with_labels_and_colorbar.png
    #       transformed_percentage_changes.png
