import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import mysql.connector
from datetime import datetime
from dotenv import load_dotenv
import os

# Ticker symbols in prediction order.
# Index i pairs with the i-th value returned by convert_from_viridis (pixels
# read bottom-to-top from the prediction column). Each name is also used as a
# column name of the predicted_percentage_changes table, so the order and
# spelling here must match both the image layout and the table schema.
# 121 entries, matching the offset_height=121 passed to convert_from_viridis.
stock_names = [
    'AAPL', 'NVDA', 'MSFT', 'GOOGL', 'AMZN', 'META', 'TSM', 'TSLA', 'AVGO', 'LLY',
    'WMT', 'JPM', 'V', 'XOM', 'UNH', 'TCEHY', 'NVO', 'ORCL', 'MA', 'HD',
    'JNJ', 'PG', 'COST', 'ABBV', 'BAC', 'NFLX', 'KO', 'ASML', 'CRM', 'SAP',
    'CVX', 'MRK', 'TMUS', 'AMD', 'BABA', 'AZN', 'PEP', 'NVS', 'TM', 'ACN',
    'CSCO', 'WFC', 'TMO', 'IBM', 'MCD', 'ADBE', 'BX', 'PM', 'ABT', 'GE',
    'QCOM', 'NOW', 'TXN', 'MS', 'AXP', 'CAT', 'DHR', 'ISRG', 'VZ', 'RY',
    'NEE', 'PDD', 'DIS', 'INTU', 'AMGN', 'GS', 'HDB', 'PFE', 'UBER', 'CMCSA',
    'T', 'HSBC', 'UL', 'LOW', 'AMAT', 'SPGI', 'TTE', 'BLK', 'UNP', 'BKNG',
    'BHP', 'HON', 'PGR', 'SYK', 'ETN', 'SNY', 'SCHW', 'LMT', 'KKR', 'BSX',
    'TJX', 'BUD', 'ANET', 'VRTX', 'COP', 'C', 'MDT', 'PANW', 'MU', 'NKE',
    'ADP', 'CB', 'ADI', 'DE', 'PLD', 'SBUX', 'UPS', 'GILD', 'MMC', 'IBN',
    'SONY', 'BMY', 'RIO', 'UBS', 'MELI', 'AMT', 'HCA', 'REGN', 'SHOP', 'PLTR',
    'SO'
]

def inverse_custom_tanh(y, k=0.01):
    """Undo the scaled-tanh squashing and recover the original value(s).

    Computes ``arctanh(y) / (arctanh(0.5) / k)``, so ``y = 0.5`` maps back to
    exactly ``k`` and ``y = 0`` maps back to 0. The forward transform is not
    in this file; presumably it is ``tanh(x * arctanh(0.5) / k)`` (verify
    against the encoder).

    Args:
        y: Scalar or NumPy array of squashed values. Must lie strictly inside
            (-1, 1) for a finite result (``np.arctanh`` gives +/-inf at +/-1
            and NaN beyond).
        k: The original value that the squashing maps to 0.5.

    Returns:
        The recovered value(s), with the same shape as ``y``.
    """
    # Keep the original evaluation order (divide by the slope, not multiply
    # by its reciprocal) so results are bit-for-bit unchanged.
    return np.arctanh(y) / (np.arctanh(0.5) / k)

def convert_from_viridis(image_path, predict_column=744, value_min=-0.32, value_max=0.51, offset_height=121, k=0.01):
    """Decode one pixel column of a viridis-coloured image back to numbers.

    Reads the bottom ``offset_height`` pixels of column ``predict_column``.
    Each pixel is matched to the nearest of 256 viridis colours, and that
    colour's position in [0, 1] is rescaled linearly onto
    [value_min, value_max]. Finally the tanh squashing is undone with
    ``inverse_custom_tanh``.

    Args:
        image_path: Path of the image to read. Any mode Pillow can convert to
            RGB is accepted; alpha is discarded.
        predict_column: Pixel column to read. Uses NumPy indexing, so negative
            values count from the right edge.
        value_min: Value represented by the first (darkest) viridis colour.
        value_max: Value represented by the last (brightest) viridis colour.
        offset_height: Number of rows to read, counted up from the bottom edge.
        k: Forwarded to ``inverse_custom_tanh`` (default matches its default).

    Returns:
        1-D float array of length ``offset_height``. Element 0 comes from the
        bottom pixel row.

    Raises:
        ValueError: If ``offset_height`` is negative or taller than the image.
        IndexError: If ``predict_column`` is outside the image width.
    """
    # Force 3-channel RGB. An RGBA or palette PNG would otherwise produce
    # pixels that do not broadcast against the (N, 3) colour table below.
    # The context manager also closes the underlying file handle.
    with Image.open(image_path) as img:
        img_array = np.asarray(img.convert("RGB"))

    height = img_array.shape[0]
    if not 0 <= offset_height <= height:
        # A negative start row would silently select the wrong rows.
        raise ValueError(
            f"offset_height={offset_height} must be between 0 and the image height ({height})"
        )

    # Bottom `offset_height` pixels of the column, reordered bottom-to-top and
    # scaled to [0, 1] to match matplotlib's colour table.
    pixels = img_array[height - offset_height:, predict_column][::-1] / 255.0

    n_colors = 256
    viridis_colors = plt.cm.viridis(np.linspace(0, 1, n_colors))[:, :3]

    # Squared distance from every pixel (P, 1, 3) to every colour (1, N, 3).
    # sqrt is monotonic, so the argmin over squared distances picks the same
    # nearest colour without computing it.
    sq_dist = ((pixels[:, None, :] - viridis_colors[None, :, :]) ** 2).sum(axis=-1)
    normalized_values = sq_dist.argmin(axis=1) / (n_colors - 1)

    intermediate_values = normalized_values * (value_max - value_min) + value_min
    return inverse_custom_tanh(intermediate_values, k=k)

if __name__ == "__main__":
    # INPUT
    #   image_path
    #   predict_column = 744
    #   value_min = 
    #   value_max = 
    #   offset_height = 
    # OUTPUT
    #   saving to predicted_percentages table
    #
    # Extract values from image
    image_path = "prediction_stocks_API.png"
    # image_path, predict_column=744, value_min=-0.32, value_max=0.51, offset_height=121
    #values = convert_from_viridis(image_path)
    values = convert_from_viridis(image_path, predict_column=744, value_min=-0.32, value_max=0.51, offset_height=121)
    
    # Print information about the values array
    print("\nValues array information:")
    print(f"Shape: {values.shape}")
    print(f"Type: {values.dtype}")
    print(f"Number of elements: {len(values)}")
    print(f"First few values: {values[:5]}")


    # Get current date in YYYY-MM-DD format
    current_date = datetime.now().strftime('%Y-%m-%d')

    # Load environment variables
    load_dotenv()

    # Connect to MySQL database
    db = mysql.connector.connect(
        host=os.getenv('STOCK_DB_HOST'),
        user=os.getenv('STOCK_DB_USER'),
        password=os.getenv('STOCK_DB_PASSWORD'),
        database=os.getenv('STOCK_DB_NAME')
    )

    cursor = db.cursor()


    # Prepare SQL query
    columns = ['TradeDate'] + stock_names[:len(values)]
    placeholders = ', '.join(['CAST(%s AS FLOAT)' if col != 'TradeDate' else '%s' for col in columns])
    update_statements = ', '.join([f"{col} = VALUES({col})" for col in columns if col != 'TradeDate'])
    
    query = f"""
    INSERT INTO predicted_percentage_changes ({', '.join(columns)})
    VALUES ({placeholders})
    ON DUPLICATE KEY UPDATE
    {update_statements}
    """


    # Prepare data for insertion, converting float64 to native Python float
    data = [current_date] + [float(value) for value in values]

    # Execute query
    try:
        cursor.execute(query, data)
        db.commit()
        print("Data inserted successfully into table 'predicted_percentage_changes'")
    except mysql.connector.Error as error:
        print(f"Error inserting data: {error}")
        db.rollback()
    finally:
        cursor.close()
        db.close()

        
    # Create a table header for console and file output
    print("\nStock Prediction:")
    print("-" * 50)
    print(f"{'Stock':<10} | {'Prediction':>12}")
    print("-" * 50)
    
    table_content = []
    for stock, value in zip(stock_names[:len(values)], values):
        table_content.append(f"{stock:<10} | {value:>12.6f}")
    
    # Print to console
    print("\n".join(table_content))
    print("-" * 50)
    print(f"Predicted value: {values[0]:.6f}")
    
    # Save to file
    with open('stock_predictions.txt', 'w') as f:
        f.write("Stock Prediction\n")
        f.write("-" * 50 + "\n")
        f.write(f"{'Stock':<10} | {'Prediction':>12}\n")
        f.write("-" * 50 + "\n")
        f.write("\n".join(table_content))
        f.write("\n" + "-" * 50 + "\n")
        f.write(f"Total stocks: {len(values)}\n")
        f.write(f"Value range: {np.min(values):.6f} to {np.max(values):.6f}\n")
    
    print("\nPrediction has been saved to 'stock_predictions.txt'")
