import base64
import html
import io
import json
import os
import traceback

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
from PIL import Image

app = Flask(__name__)
# NOTE(review): CORS(app) without arguments uses flask-cors defaults, which
# accept cross-origin requests from any origin — restrict `origins=` if the
# front end is served from a known host.
CORS(app)

# Module-level state shared between requests. Everything starts as None and
# holds the result of the most recent /process call only.

# Image geometry, written by create_heatmap_image().
stored_x_mult = stored_y_mult = None
stored_img_width = stored_img_height = None

# Value ranges and the numeric DataFrame, written by process_data().
stored_pct_global_min = stored_pct_global_max = None  # percentage changes
stored_trans_global_min = stored_trans_global_max = None  # after custom_tanh
stored_global_min = stored_global_max = None  # original numeric values
stored_numeric_df = None  # numeric columns of the parsed DataFrame

def encode_image(image_path):
    """Read the file at *image_path* in binary mode and return its contents as base64 text."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def image_to_base64(img):
    """Serialise *img* (e.g. a PIL Image) as PNG in memory and return the bytes as base64 text."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format='PNG')
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()

def save_debug_image(img, filename):
    """
    Best-effort write of *img* to *filename* for debugging.

    A numpy array is first converted with ``Image.fromarray``; anything else is
    expected to provide ``save(filename)`` (e.g. a PIL Image). Any failure is
    printed rather than raised, so debug output never breaks the caller.
    """
    try:
        to_save = Image.fromarray(img) if isinstance(img, np.ndarray) else img
        to_save.save(filename)
        print(f"Successfully saved debug image: {filename}")
    except Exception as exc:
        print(f"Error saving debug image {filename}: {str(exc)}")

def inverse_custom_tanh(y, k=0.01):
    """
    Undo ``custom_tanh``: map tanh output(s) *y* back to the input that produced them.

    Computes ``arctanh(y) / (arctanh(0.5) / k)``, so ``y = ±0.5`` maps back to
    ``±k``. Following numpy's ``arctanh``, ``y = ±1`` yields ``±inf`` and values
    outside [-1, 1] yield NaN.
    """
    return np.arctanh(y) / (np.arctanh(0.5) / k)

def process_image_to_values(img_array):
    """
    Convert heatmap pixel colours back to percentage-change values.

    Each pixel is matched to the nearest (Euclidean RGB distance) of 256 evenly
    spaced viridis colours, giving a normalised value in [0, 1]. The heatmap
    produced by ``create_heatmap_image`` colours each cell with
    ``viridis(Normalize(trans_min, trans_max)(custom_tanh(pct_change)))``, so
    the decoding is:

    1. rescale to [stored_trans_global_min, stored_trans_global_max] — the
       range the colours were normalised against;
    2. apply ``inverse_custom_tanh`` to recover the percentage change.

    If no transformed range is stored yet, the normalised values are returned
    unchanged (with a warning).

    Parameters:
    img_array (ndarray): an (H, W, C) image with C >= 3 (channels beyond RGB,
        e.g. alpha, are ignored), or a 2-D array treated as a list of RGB pixels.

    Returns:
    ndarray: (H, W) values for 3-D input, 1-D otherwise. Entries are ±inf where
    the rescaled value is exactly ±1 (tanh saturation).
    """
    trans_min = stored_trans_global_min
    trans_max = stored_trans_global_max
    print(f"Using stored transformed min: {trans_min}")
    print(f"Using stored transformed max: {trans_max}")

    N = 256
    viridis_colors = plt.cm.viridis(np.linspace(0, 1, N))[:, :3]

    def find_closest_value(pixel_color):
        # Index of the nearest palette entry, scaled to [0, 1].
        pixel_color = pixel_color / 255.0
        distances = np.sqrt(np.sum((viridis_colors - pixel_color) ** 2, axis=1))
        return np.argmin(distances) / (N - 1)

    if img_array.ndim == 3:
        # Keep only RGB: reshaping an RGBA array straight to (-1, 3) would split
        # pixels across rows and scramble every colour.
        pixels = img_array[..., :3].reshape(-1, 3)
    else:
        pixels = img_array

    print(f"pixels array shape: {pixels.shape}")
    normalized_values = np.array([find_closest_value(pixel) for pixel in pixels])
    print(f"Normalized array shape: {normalized_values.shape}")
    print(f"Normalized Values: {normalized_values}")

    if trans_min is not None and trans_max is not None:
        transformed_values = normalized_values * (trans_max - trans_min) + trans_min
        print(f"Rescaled Values: {transformed_values[:5]}")
        original_values = inverse_custom_tanh(transformed_values)
        print(f"Original Values: {original_values[:5]}")
        print(f"Original array shape: {original_values.shape}")
    else:
        print("Warning: No stored transformed min/max values available, using normalized values directly")
        original_values = normalized_values  # Fallback if global values aren't available

    # Restore one row per image row for 3-D input.
    if img_array.ndim == 3:
        original_values = original_values.reshape(img_array.shape[0], -1)

    return original_values

def inpaint(timeout=300):
    """
    Perform image inpainting using the GetImg.ai API.

    Sends original_input.png and mask_input.png (base64-encoded) to the
    stable-diffusion inpaint endpoint, writes the returned image to
    prediction_stocks_API.png and returns it base64-encoded.

    The API key is read from the GETIMG_API_KEY environment variable. A key
    that was previously hard-coded in this file should be treated as leaked
    and rotated.

    Parameters:
    timeout (float): Seconds to wait for the API before giving up (default 300).

    Returns:
    str: base64-encoded image returned by the API.

    Raises:
    Exception: "Inpainting failed: ..." for any failure — missing API key,
    unreadable input files, network error/timeout, non-200 response or a
    response without an 'image' field.
    """
    try:
        api_key = os.environ.get("GETIMG_API_KEY")
        if not api_key:
            raise RuntimeError("GETIMG_API_KEY environment variable is not set")

        image = encode_image("original_input.png")
        mask_image = encode_image("mask_input.png")

        url = "https://api.getimg.ai/v1/stable-diffusion/inpaint"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "image": image,
            "mask_image": mask_image,
            "model": "stable-diffusion-v1-5-inpainting",
            "prompt": "continue this pixel pattern, use only pixels with a single color",
            "negative_prompt": "artifacts, blurry, distorted, stretch, charts, graphs, documents",
            "steps": 100,
            "cfg_scale": 7,
            "width": 1024,
            "height": 1024,
            "output_format": "png",
            "scheduler": "euler_a",
        }

        # Without a timeout a stalled connection would block this request forever.
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"API request failed with status code {response.status_code}: {response.text}")

        image_bytes = base64.b64decode(response.json()['image'])
        with open("prediction_stocks_API.png", "wb") as f:
            f.write(image_bytes)

        return base64.b64encode(image_bytes).decode('utf-8')

    except Exception as e:
        raise Exception(f"Inpainting failed: {str(e)}") from e

def calculate_max_multiple_less_than_1024(n):
    """
    Find the largest positive integer ``x`` such that ``x * n`` is a multiple
    of 64 and strictly below 1024.

    Returns ``(x, x * n)``, or ``(None, None)`` when no such ``x`` exists
    (e.g. ``n`` >= 1024 or ``n`` <= 0). Used to pick how many pixel rows each
    data row is stretched to.
    """
    for x in range((1024 - 1) // n, 0, -1):
        product = x * n
        if product % 64 == 0 and product < 1024:
            return x, product
    return None, None

def calculate_max_multiple_up_to_1024(n):
    """
    Find the largest positive integer ``x`` such that ``x * n`` is a multiple
    of 64 and at most 1024.

    Returns ``(x, x * n)``, or ``(None, None)`` when no such ``x`` exists.
    Used to pick how many pixel columns each data column is stretched to.
    """
    candidates = ((x, x * n) for x in range(1024 // n, 0, -1))
    return next(
        ((x, product) for x, product in candidates if product % 64 == 0 and product <= 1024),
        (None, None),
    )

def custom_tanh(x, k=0.01):
    """
    Squash *x* with a tanh whose slope is chosen so that ``±k`` maps to ``±0.5``.

    Parameters:
    x (float or array): Input value(s); numpy arrays and pandas Series are
        processed element-wise
    k (float): Input magnitude that should map to an output of 0.5

    Returns:
    float or array: ``tanh(x * arctanh(0.5) / k)``
    """
    return np.tanh(x * (np.arctanh(0.5) / k))

def create_heatmap_image(df_transformed, global_min, global_max):
    """
    Render the transformed dataframe as a viridis heatmap and build the inpainting inputs.

    Every cell becomes a ``y_mult`` x ``x_mult`` block of identical pixels. The
    multipliers come from ``calculate_max_multiple_less_than_1024`` (rows) and
    ``calculate_max_multiple_up_to_1024`` (columns); if either finds no valid
    multiplier, both fall back to 1.

    Side effects:
    - sets the globals stored_x_mult, stored_y_mult, stored_img_width and
      stored_img_height;
    - writes data_with_labels_and_colorbar.png, original_input.png and
      mask_input.png to the current working directory;
    - calls inpaint(); a failure there is printed and yields a None prediction.

    Parameters:
    df_transformed (DataFrame): The transformed dataframe; rows are plotted as
        "Trade Date" and columns as "Stocks"
    global_min (float): Global minimum value for normalization
    global_max (float): Global maximum value for normalization

    Returns:
    tuple: (heatmap base64, original input base64, mask base64, image width,
    image height, x_mult, y_mult, prediction base64 or None, labeled heatmap base64)
    """
    global stored_x_mult, stored_y_mult, stored_img_width, stored_img_height

    height, width = df_transformed.shape

    y_mult, new_height = calculate_max_multiple_less_than_1024(height)
    x_mult, new_width = calculate_max_multiple_up_to_1024(width)
    if y_mult is None or x_mult is None:
        # Fallback to one pixel per cell if no valid multipliers were found
        y_mult, x_mult = 1, 1
        new_height, new_width = height, width

    # Store values for use in other endpoints
    stored_x_mult = x_mult
    stored_y_mult = y_mult
    stored_img_width = new_width
    stored_img_height = new_height

    norm = Normalize(vmin=global_min, vmax=global_max)
    rgb_array = _cells_to_rgb_array(df_transformed, norm, y_mult, x_mult)

    labeled_path = 'data_with_labels_and_colorbar.png'
    _save_labeled_heatmap(rgb_array, norm, df_transformed, y_mult, x_mult, labeled_path)

    img = Image.fromarray(rgb_array, 'RGB')
    original_input, mask_input = _build_inpainting_inputs(img)
    original_input.save('original_input.png', dpi=(300, 300))
    mask_input.save('mask_input.png', dpi=(300, 300))

    heatmap_str = image_to_base64(img)
    original_input_str = image_to_base64(original_input)
    mask_input_str = image_to_base64(mask_input)
    labeled_heatmap_str = encode_image(labeled_path)

    # Generate prediction using inpainting
    try:
        prediction_str = inpaint()
    except Exception as e:
        prediction_str = None
        print(f"Inpainting error: {str(e)}")

    return (heatmap_str, original_input_str, mask_input_str, new_width, new_height,
            x_mult, y_mult, prediction_str, labeled_heatmap_str)


def _cells_to_rgb_array(df_transformed, norm, y_mult, x_mult):
    """Colour every cell with viridis(norm(value)) and expand it into a y_mult x x_mult uint8 RGB block."""
    colors = plt.cm.viridis(norm(df_transformed.to_numpy(dtype=float)))[..., :3]
    # astype truncates, matching the previous per-cell assignment into a uint8 array.
    cell_rgb = (colors * 255).astype(np.uint8)
    return np.repeat(np.repeat(cell_rgb, y_mult, axis=0), x_mult, axis=1)


def _save_labeled_heatmap(rgb_array, norm, df_transformed, y_mult, x_mult, path, dpi=150):
    """Save the heatmap with axis labels, first/middle/last tick labels and a colourbar to *path*."""
    n_rows, n_cols = df_transformed.shape
    pixel_height, pixel_width = rgb_array.shape[:2]

    # A standalone Figure avoids pyplot's global state, which is not safe to
    # share between request threads and may try to start a GUI backend.
    fig = Figure(figsize=(pixel_width / dpi, pixel_height / dpi), dpi=dpi)
    ax = fig.subplots()
    ax.imshow(rgb_array, aspect='auto', interpolation='nearest')
    ax.set_xlabel('Stocks')
    ax.set_ylabel('Trade Date')

    # The image is expanded by x_mult / y_mult, so cell indices are converted
    # to the pixel coordinate of each cell's centre before placing ticks.
    col_idx = [0, n_cols // 2, n_cols - 1]
    row_idx = [0, n_rows // 2, n_rows - 1]
    ax.set_xticks([j * x_mult + (x_mult - 1) / 2 for j in col_idx])
    ax.set_yticks([i * y_mult + (y_mult - 1) / 2 for i in row_idx])
    ax.set_xticklabels([df_transformed.columns[j] for j in col_idx], rotation=45, ha='right')
    ax.set_yticklabels([df_transformed.index[i] for i in row_idx])

    cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap='viridis'), ax=ax)
    cbar.set_label('Transformed Percentage Change')

    fig.tight_layout()
    fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)


def _build_inpainting_inputs(img, size=1024):
    """Return (original_input, mask): img rotated with Image.ROTATE_90 at the bottom-left of a white canvas, and a white canvas that is black over that same area."""
    rotated_img = img.transpose(Image.ROTATE_90)
    position = (0, size - rotated_img.height)

    original_input = Image.new('RGB', (size, size), color='white')
    original_input.paste(rotated_img, position)

    mask = Image.new('RGB', (size, size), color='white')
    mask.paste(Image.new('RGB', rotated_img.size, color='black'), position)
    return original_input, mask

@app.route('/')
def health_check():
    """Liveness probe for the root URL: always answers with the same plain-text message."""
    status_message = "Flask server is running!"
    return status_message

@app.route('/predicted-region')
def get_predicted_region():
    """
    Show the full inpainting result and the part of it generated beyond the input heatmap.

    Reads prediction_stocks_API.png (written by inpaint()) and
    original_input.png (written by create_heatmap_image()). The input heatmap
    sits at the bottom-left of original_input.png, so every column of the
    prediction to the right of it is newly generated.

    Returns:
    str: HTML page with both images, or (error message, 500) on failure.
    """
    try:
        # Context managers close the image files once their pixels are copied.
        with Image.open('prediction_stocks_API.png') as prediction:
            prediction_array = np.array(prediction)
        with Image.open('original_input.png') as original:
            original_array = np.array(original)

        # Width of the input heatmap = non-white pixels in the BOTTOM row. The
        # heatmap is pasted flush with the bottom edge, so that row always
        # crosses it, whereas the top row is entirely white whenever the
        # heatmap is shorter than the canvas (which would give an offset of 0).
        bottom_row_white = np.all(original_array[-1, :, :3] == 255, axis=-1)
        data_cols = original_array.shape[1] - int(np.sum(bottom_row_white))

        # Extract only the predicted region (right side)
        predicted_region = prediction_array[:, data_cols:]
        predicted_region_image = Image.fromarray(predicted_region)
        predicted_region_str = image_to_base64(predicted_region_image)

        # Return HTML showing the extracted region
        return f"""
        <h2>Predicted Region</h2>
        <div style='display: flex; flex-direction: column; gap: 20px;'>
            <div>
                <h3>Full Prediction Image</h3>
                <img src='data:image/png;base64,{encode_image("prediction_stocks_API.png")}' alt='Full prediction image'>
            </div>
            <div>
                <h3>Extracted Predicted Region</h3>
                <img src='data:image/png;base64,{predicted_region_str}' alt='Predicted region'>
            </div>
        </div>
        """
    except Exception as e:
        return f"Error extracting predicted region: {str(e)}", 500

@app.route('/reduced-predicted-region')
def get_reduced_predicted_region():
    """
    Decode the predicted region of the inpainting result into numbers.

    Steps:
    1. cut out the columns of prediction_stocks_API.png to the right of the
       input heatmap in original_input.png;
    2. rotate the cut-out 90 degrees clockwise (rotate(-90));
    3. keep one pixel column every ``x_mult`` pixels (``stored_x_mult``, or
       64 if unset) across the heatmap width (``stored_img_width`` when known);
    4. convert colours to values with ``process_image_to_values``.

    Intermediate images are written as debug_*.png via save_debug_image().

    Returns:
    str: HTML page with the extracted/reduced images and a value table, or
    (error message, 500) on failure.
    """
    try:
        print("Starting reduced predicted region processing...")

        # convert('RGB') guarantees exactly three channels even if the API
        # returns a PNG with alpha; the colour decoding expects RGB pixels.
        with Image.open('prediction_stocks_API.png') as prediction:
            print(f"Loaded prediction image, size: {prediction.size}")
            prediction_array = np.array(prediction.convert('RGB'))
        with Image.open('original_input.png') as original:
            original_array = np.array(original)

        # Width of the input heatmap = non-white pixels in the BOTTOM row (the
        # heatmap is pasted flush with the bottom edge; the top row is all
        # white whenever the heatmap is shorter than the canvas).
        bottom_row_white = np.all(original_array[-1, :, :3] == 255, axis=-1)
        data_cols = original_array.shape[1] - int(np.sum(bottom_row_white))
        print(f"Input heatmap width in original_input.png: {data_cols}")
        print(f"Prediction array shape: {prediction_array.shape}")

        # Extract only the predicted region (right side)
        predicted_region = prediction_array[:, data_cols:, :]
        print(f"Extracted region shape: {predicted_region.shape}")
        predicted_region_image = Image.fromarray(predicted_region)
        save_debug_image(predicted_region_image, 'debug_original_region.png')

        # Rotate 90 degrees clockwise
        rotated_region = predicted_region_image.rotate(-90, expand=True)
        rotated_array = np.array(rotated_region)
        print(f"Rotated array shape: {rotated_array.shape}")
        save_debug_image(rotated_region, 'debug_rotated_region.png')

        x_mult = stored_x_mult if stored_x_mult is not None else 64  # fallback value
        print(f"Using x_mult: {x_mult}")

        # After rotation, only the first stored_img_width columns correspond to
        # heatmap data; columns beyond that come from the white canvas padding
        # above the heatmap in original_input.png.
        _, width, _ = rotated_array.shape
        data_width = width if stored_img_width is None else min(width, stored_img_width)

        # First column and every x_mult-th column after it: one sample per data column.
        # NOTE(review): these are the left edges of each block; block centres
        # (offset x_mult // 2) may be less affected by inpainting blur.
        column_positions = list(range(0, data_width, x_mult))
        print(f"Selected {len(column_positions)} columns at positions: {column_positions}")
        reduced_array = np.ascontiguousarray(rotated_array[:, column_positions, :])
        print(f"Reduced array shape: {reduced_array.shape}")

        reduced_image = Image.fromarray(reduced_array)
        save_debug_image(reduced_image, 'debug_reduced_region.png')

        # Process the reduced image to get numerical values
        numerical_values = process_image_to_values(reduced_array)
        print(f"Numerical values shape: {numerical_values.shape}")
        print(f"Sample of numerical values: {numerical_values[:5, :5]}")

        cell_template = "<td style='border: 1px solid #ddd; padding: 8px;'>{:.4f}</td>"
        rows_html = "".join(
            "<tr>" + "".join(cell_template.format(val) for val in row) + "</tr>"
            for row in numerical_values
        )
        numerical_values_html = (
            "<div style='overflow-x: auto;'><table style='border-collapse: collapse;'>"
            f"{rows_html}</table></div>"
        )

        predicted_region_str = image_to_base64(predicted_region_image)
        reduced_region_str = image_to_base64(reduced_image)

        print("Successfully processed reduced predicted region")

        # Return HTML showing both the original and reduced regions, plus numerical values
        return f"""
        <h2>Reduced Predicted Region</h2>
        <div style='display: flex; flex-direction: column; gap: 20px;'>
            <div>
                <h3>Original Extracted Region</h3>
                <img src='data:image/png;base64,{predicted_region_str}' alt='Original extracted region'>
            </div>
            <div>
                <h3>Reduced Region (90° rotated, selected columns)</h3>
                <img src='data:image/png;base64,{reduced_region_str}' alt='Reduced region'>
            </div>
            <div>
                <h3>Numerical Values (Processed from Reduced Region)</h3>
                {numerical_values_html}
            </div>
        </div>
        """
    except Exception as e:
        print(f"Error in reduced predicted region: {str(e)}")
        print(traceback.format_exc())
        return f"Error creating reduced predicted region: {str(e)}", 500

@app.route('/process', methods=['POST'])
def process_data():
    """
    Parse table data, transform it and render the heatmap / inpainting images.

    Expects a JSON request body ``{"data": "<JSON string>"}``. The inner JSON
    is parsed and, when possible, loaded into a DataFrame (indexed by
    ``TradeDate`` or ``Date`` if present). For the int64/float64 columns it:
    - computes percentage changes (NaN -> 0);
    - applies ``custom_tanh``;
    - renders images via ``create_heatmap_image``;
    - records the value ranges and numeric DataFrame in the module-level
      ``stored_*`` globals for the other endpoints.

    Returns:
    str: an HTML report, or (message, 400) when the body or inner JSON is invalid.
    """
    global stored_pct_global_min, stored_pct_global_max, stored_trans_global_min
    global stored_trans_global_max, stored_global_min, stored_global_max
    global stored_numeric_df

    body = request.get_json(silent=True)
    data = body.get('data') if isinstance(body, dict) else None
    if not isinstance(data, str):
        return "Request body must be a JSON object with a string 'data' field", 400

    try:
        parsed_data = json.loads(data)
    except json.JSONDecodeError:
        return "Invalid JSON data", 400

    # The parsed data is user-controlled and embedded in HTML: escape it so it
    # cannot inject markup or scripts into the page.
    formatted_data = html.escape(json.dumps(parsed_data, indent=2))
    dimensions = _describe_dimensions(parsed_data)

    df_info = ""
    try:
        df = pd.DataFrame(parsed_data)

        # Ensure columns are flat (not multi-index)
        if isinstance(df.columns, pd.MultiIndex):
            df.columns = df.columns.get_level_values(-1)

        # Set TradeDate or Date as index if they exist
        if 'TradeDate' in df.columns:
            df.set_index('TradeDate', inplace=True)
        elif 'Date' in df.columns:
            df.set_index('Date', inplace=True)

        numeric_df = df.select_dtypes(include=['int64', 'float64'])
        stored_numeric_df = numeric_df
        global_min = numeric_df.min().min()
        global_max = numeric_df.max().max()
        stored_global_min = global_min
        stored_global_max = global_max

        # Calculate percentage changes and fill NaN values with 0
        pct_change_df = numeric_df.pct_change().fillna(0)
        stored_pct_global_min = pct_change_df.min().min()
        stored_pct_global_max = pct_change_df.max().max()

        transformed_df = pct_change_df.apply(custom_tanh)
        stored_trans_global_min = transformed_df.min().min()
        stored_trans_global_max = transformed_df.max().max()

        (heatmap_img, original_input_str, mask_input_str, img_width, img_height,
         x_mult, y_mult, prediction_img, labeled_heatmap_str) = create_heatmap_image(
            transformed_df,
            stored_trans_global_min,
            stored_trans_global_max,
        )

        # DataFrame.to_html escapes cell and label text by default.
        df_head = df.head().to_html()
        pct_change_head = pct_change_df.head().to_html()
        transformed_head = transformed_df.head().to_html()
        df_shape = f"DataFrame dimensions: {df.shape[0]} rows × {df.shape[1]} columns"
        min_max_info = f"<h3>Original Data - Global Min and Max Values:</h3><p>Minimum value across all data: {global_min}<br>Maximum value across all data: {global_max}</p>"
        pct_min_max_info = f"<h3>Percentage Changes - Global Min and Max Values:</h3><p>Minimum percentage change: {stored_pct_global_min:.4f}<br>Maximum percentage change: {stored_pct_global_max:.4f}</p>"
        trans_min_max_info = f"<h3>Transformed Values - Global Min and Max Values:</h3><p>Minimum transformed value: {stored_trans_global_min:.4f}<br>Maximum transformed value: {stored_trans_global_max:.4f}</p>"
        heatmap_html = f"<h3>Heatmap Visualization:</h3><p>Image dimensions: {img_width} × {img_height} pixels<br>X multiplier: {x_mult}, Y multiplier: {y_mult}</p><img src='data:image/png;base64,{heatmap_img}' alt='Heatmap of transformed data'><p>Heatmap with Labels and Colorbar:</p><img src='data:image/png;base64,{labeled_heatmap_str}' alt='Labeled heatmap with colorbar'>"
        additional_images_html = f"<h3>Additional Visualizations:</h3><p>Original Input Image:</p><img src='data:image/png;base64,{original_input_str}' alt='Original input image'><p>Mask Input Image:</p><img src='data:image/png;base64,{mask_input_str}' alt='Mask input image'>"

        if prediction_img:
            additional_images_html += f"<p>Prediction Image:</p><img src='data:image/png;base64,{prediction_img}' alt='Prediction image'>"

        df_info = f"<h3>Original DataFrame:</h3>{df_head}<h3>Percentage Changes DataFrame (NaN values replaced with 0):</h3>{pct_change_head}<h3>Transformed DataFrame (using custom_tanh):</h3>{transformed_head}<p>{df_shape}</p>{min_max_info}{pct_min_max_info}{trans_min_max_info}{heatmap_html}{additional_images_html}"
    except Exception as e:
        # Best effort: the raw JSON is still reported even if analysis fails.
        df_info = f"<p>Note: Data could not be converted to a DataFrame. Error: {html.escape(str(e))}</p>"

    return f"<h2>Processed Data:</h2><pre>{formatted_data}</pre><h3>Dimensions:</h3><p>{dimensions}</p>{df_info}"


def _describe_dimensions(parsed_data):
    """Return a one-line summary of the top-level shape of parsed JSON data."""
    if isinstance(parsed_data, dict):
        return f"Number of top-level keys: {len(parsed_data)}"
    if isinstance(parsed_data, list):
        return f"Array length: {len(parsed_data)}"
    return "Data is a single value"

if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach the server execute
    # arbitrary Python, so it must not be on by default while listening on
    # all interfaces. Opt in for local development with FLASK_DEBUG=1
    # (or true/yes/on).
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(host='0.0.0.0', debug=debug_enabled)
