from modules import fetch_and_process_stock_data
from modules import convert_to_image
from modules import inpaint
import mysql.connector
from datetime import datetime
from dotenv import load_dotenv
import os
import numpy as np
from modules import convert_from_viridis, stock_names
from modules import get_next_working_day

def quote_identifier(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Quoting allows ticker symbols that collide with SQL reserved words
    (e.g. ``ALL``, ``KEY``, ``ON``) or that contain characters such as ``.``
    to be used as column names. Embedded backticks are escaped by doubling,
    which is MySQL's escaping rule for quoted identifiers.
    """
    return "`" + str(name).replace("`", "``") + "`"


def build_upsert_query(table, stock_columns):
    """Build the INSERT ... ON DUPLICATE KEY UPDATE statement for one prediction row.

    Args:
        table: target table name. It is a trusted constant, so it is not quoted.
        stock_columns: stock column names, one per predicted value. Must be non-empty,
            otherwise the UPDATE clause would be empty and the SQL invalid.

    Returns:
        A query string with one ``%s`` placeholder for ``TradeDate`` followed by one
        ``CAST(%s AS FLOAT)`` placeholder per stock column, in the same order.
    """
    quoted = [quote_identifier(col) for col in ["TradeDate", *stock_columns]]
    placeholders = ", ".join(["%s"] + ["CAST(%s AS FLOAT)"] * len(stock_columns))
    # TradeDate (quoted[0]) is the row key, so only the stock columns are updated.
    update_statements = ", ".join(f"{col} = VALUES({col})" for col in quoted[1:])
    return f"""
    INSERT INTO {table} ({', '.join(quoted)})
    VALUES ({placeholders})
    ON DUPLICATE KEY UPDATE
    {update_statements}
    """


def align_values_with_names(values, names):
    """Truncate *values* and *names* to a common length so they pair one-to-one.

    If the prediction yields more values than there are stock names, the surplus
    values have no column to go into. Without truncation the INSERT would carry
    more parameters than placeholders and fail, so they are dropped with a warning.

    Returns:
        ``(values, names)``, both of length ``min(len(values), len(names))``.
        *names* is returned as a list.
    """
    count = min(len(values), len(names))
    if count < len(values):
        print(
            f"Warning: {len(values)} predicted values but only {len(names)} stock names; "
            f"ignoring {len(values) - count} extra value(s)"
        )
    return values[:count], list(names[:count])


def format_prediction_rows(names, values):
    """Return one fixed-width ``stock | prediction`` table row per (name, value) pair."""
    return [f"{stock:<10} | {value:>12.6f}" for stock, value in zip(names, values)]


def save_predictions_to_db(trade_date, names, values, table="predicted_percentage_changes"):
    """Upsert one row of predicted percentage changes keyed by *trade_date*.

    Connection settings are read from the STOCK_DB_HOST / STOCK_DB_USER /
    STOCK_DB_PASSWORD / STOCK_DB_NAME environment variables.

    A ``mysql.connector.Error`` raised by the insert is printed and rolled back
    rather than re-raised, so the report steps that follow still run. Errors
    while connecting propagate. The write is skipped when there are no values.
    """
    if not names:
        print("No predicted values to insert; skipping database write")
        return

    db = mysql.connector.connect(
        host=os.getenv('STOCK_DB_HOST'),
        user=os.getenv('STOCK_DB_USER'),
        password=os.getenv('STOCK_DB_PASSWORD'),
        database=os.getenv('STOCK_DB_NAME')
    )
    try:
        cursor = db.cursor()
        try:
            # Convert numpy float64 to native Python float for the connector.
            data = [trade_date] + [float(value) for value in values]
            cursor.execute(build_upsert_query(table, names), data)
            db.commit()
            print(f"Data inserted successfully into table '{table}'")
        except mysql.connector.Error as error:
            print(f"Error inserting data: {error}")
            db.rollback()
        finally:
            cursor.close()
    finally:
        # Close the connection even if creating the cursor itself failed.
        db.close()


def write_predictions_file(path, rows, values):
    """Write the prediction table plus a count/range summary to the text file *path*.

    The value-range line is omitted when *values* is empty, because ``np.min``
    and ``np.max`` raise on empty input.
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.write("Stock Prediction\n")
        f.write("-" * 50 + "\n")
        f.write(f"{'Stock':<10} | {'Prediction':>12}\n")
        f.write("-" * 50 + "\n")
        f.write("\n".join(rows))
        f.write("\n" + "-" * 50 + "\n")
        f.write(f"Total stocks: {len(values)}\n")
        if len(values):
            f.write(f"Value range: {np.min(values):.6f} to {np.max(values):.6f}\n")


def main():
    """Run the full pipeline: fetch EOD data, encode it as an image, inpaint, decode, store, report."""
    # STEP (1): fetch and store EOD data for all names in prediction_names_100B, then
    # compute percentage changes from Adjusted Close and store them in percentage_changes.
    #   input_start_days: how many days back to start, e.g. 360 (1y), 360*3 (3y)
    #   input_end_days:   how many days back to end, e.g. 1 day back
    input_start_days = 360 * 3
    input_end_days = 1
    start_date, end_date = fetch_and_process_stock_data(input_start_days, input_end_days)
    print(f"Start Date: {start_date}")
    print(f"End Date: {end_date}")

    # STEP (2): scale the percentage changes with TANH (+/-1% -> +/-0.5), normalize,
    # encode them with the VIRIDIS colormap (one pixel per datapoint), and write the
    # original image, the white-extended image, and the mask image (original area masked black).
    # Files written: original_input.png, mask_input.png,
    #   data_with_labels_and_colorbar.png, transformed_percentage_changes.png
    (global_min, global_max, height, width,
     output_min, output_max, prediction_x, prediction_y) = convert_to_image()
    print(f"Global Min: {global_min}")
    print(f"Global Max: {global_max}")
    print(f"Value Min: {output_min}")
    print(f"Value Max: {output_max}")
    print(f"Height: {height}")
    print(f"Width: {width}")
    print(f"Prediction Height: {prediction_y}")
    print(f"Prediction Width: {prediction_x}")

    # STEP (3): pass the mask image and original image to the AI inpainting step.
    # Reads original_input.png and mask_input.png, writes prediction_stocks_API.png.
    # Use the dimensions computed in step 2 so the inpaint size cannot drift
    # from the generated images (x is the width, y is the height, as printed above).
    inpaint(prediction_width=prediction_x, prediction_height=prediction_y)

    # STEP (4): convert the predicted image back to percentage changes.
    #   predict_column: column of the predicted image to read
    #   value_min / value_max: original percentage-change range from step 2
    #   offset_height: crop offset that restores the original array length
    # NOTE(review): 744 and 121 match the example height/width documented for step 2;
    # confirm whether they should be derived from `height` / `width` instead.
    image_path = "prediction_stocks_API.png"
    values = convert_from_viridis(image_path, predict_column=744, value_min=output_min,
                                  value_max=output_max, offset_height=121)

    print("\nValues array information:")
    print(f"Shape: {values.shape}")
    print(f"Type: {values.dtype}")
    print(f"Number of elements: {len(values)}")
    print(f"First few values: {values[:5]}")

    # The prediction is stored against the next working day.
    current_date = get_next_working_day()
    print(f"Next working day: {current_date}")

    load_dotenv()

    values, names = align_values_with_names(values, stock_names)
    save_predictions_to_db(current_date, names, values)

    # Report the table on the console and in a text file.
    print("\nStock Prediction:")
    print("-" * 50)
    print(f"{'Stock':<10} | {'Prediction':>12}")
    print("-" * 50)

    rows = format_prediction_rows(names, values)
    print("\n".join(rows))
    print("-" * 50)
    if len(values):
        print(f"Predicted value: {values[0]:.6f}")

    write_predictions_file('stock_predictions.txt', rows, values)
    print("\nPrediction has been saved to 'stock_predictions.txt'")


if __name__ == "__main__":
    main()