import base64
import json
import os

import requests

def encode_image(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string.

    The file is read in binary mode. The base64 output is pure ASCII, so
    decoding it as UTF-8 yields a plain ``str`` suitable for a JSON payload.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def inpaint(
    image_path="original_input.png",
    mask_path="mask_input.png",
    output_path="prediction_stocks_API.png",
    api_key=None,
    timeout=120,
):
    """
    Perform image inpainting using the GetImg.ai Stable Diffusion API.

    Reads an image and a mask from disk and sends both, base64-encoded, to the
    GetImg.ai inpaint endpoint. The base64 ``image`` field of the JSON reply is
    decoded and written to ``output_path``. With no arguments this keeps the
    original file-to-file behavior:
    original_input.png + mask_input.png -> prediction_stocks_API.png.

    Args:
        image_path: Path of the image to inpaint.
        mask_path: Path of the mask image.
        output_path: Where to write the decoded result bytes.
        api_key: GetImg.ai API key. If None, it is read from the
            ``GETIMG_API_KEY`` environment variable.
        timeout: Seconds to wait on the HTTP request before giving up.

    Raises:
        Exception: Message "Inpainting failed: ..." on any failure, e.g. no
            API key, unreadable input file, network error or timeout, non-200
            status, or a reply without an ``image`` field. The original error
            is chained as ``__cause__``.
    """
    try:
        # The key used to be hard-coded in this file. That key was committed
        # to source and should be treated as leaked and rotated.
        if api_key is None:
            api_key = os.environ.get("GETIMG_API_KEY")
        if not api_key:
            raise Exception("no API key: pass api_key or set GETIMG_API_KEY")

        # Encode both images
        image = encode_image(image_path)
        mask_image = encode_image(mask_path)

        # API endpoint and headers
        url = "https://api.getimg.ai/v1/stable-diffusion/inpaint"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        # Prepare the payload
        payload = {
            "image": image,
            "mask_image": mask_image,
            "model": "stable-diffusion-v1-5-inpainting",
            "prompt": "continue this pixel pattern, use only pixels with a single color that belongs to this colormap",
            "negative_prompt": "artifacts, blurry, distorted, stretch, charts, graphs, documents",
            "steps": 100,
            "cfg_scale": 7,
            "width": 768,
            "height": 256,
            "output_format": "png",
            "scheduler": "euler_a",
        }

        # Without a timeout, requests can block indefinitely on a stalled
        # connection.
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)

        if response.status_code != 200:
            raise Exception(f"API request failed with status code {response.status_code}: {response.text}")

        # The reply carries the result image as a base64 string under 'image'.
        result = response.json()
        image_bytes = base64.b64decode(result['image'])

        # Save the result
        with open(output_path, "wb") as f:
            f.write(image_bytes)

    except Exception as e:
        # Keep the original exception type and message format for callers,
        # but chain the cause so the real traceback is not lost.
        raise Exception(f"Inpainting failed: {str(e)}") from e

if __name__ == "__main__":
    # Script entry point: run the file-to-file inpainting pipeline.
    inpaint()

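#   Script usage: reads no command-line arguments or stdin and writes
#   nothing to stdout. All input and output goes through files in the
#   current working directory.
#
#   input:
#       original_input.png   image to inpaint
#       mask_input.png       inpainting mask
#   output:
#       prediction_stocks_API.png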