import joblib
import logging
import os
from pathlib import Path
import pandas as pd
from scripts.get_prediction_input import get_prediction_input
from scripts.update_data import update_data as update


# Serialized model file read by load_model() via joblib (relative to the CWD).
MODEL_PATH = Path('models/fpl_model.pkl')
# Directory where predict() writes gw{gw}_predictions.csv files (relative to the CWD).
# NOTE(review): "2526" presumably denotes the 2025/26 season — confirm.
PREDICTIONS_PATH = Path('data/predictions/2526')
logger = logging.getLogger(__name__)
# Module-wide model cache: populated by load_model(); stays None until a load succeeds.
model = None
def load_model():
    """
    Load the pre-trained model from ``MODEL_PATH`` and cache it module-wide.

    The model is loaded at most once: while the module-level ``model`` global
    is set, later calls return it without touching the disk.

    :return: The deserialized model, or ``None`` if ``MODEL_PATH`` does not
        exist. A missing file is logged so the failure is not silent, and the
        cache stays empty so a later call can retry.
    """
    global model
    if model is None:
        try:
            # joblib.load unpickles the file: only point MODEL_PATH at trusted files.
            model = joblib.load(MODEL_PATH)
        except FileNotFoundError:
            # Best-effort: keep returning None as before, but surface why.
            # (model is still None here because the assignment never ran.)
            logger.error("Model file not found at %s; model not loaded.", MODEL_PATH)

    return model

def predict(gw: int) -> pd.DataFrame:
    """
    Predict expected points (xP) for every player in gameweek ``gw``.

    Builds the feature frame with ``get_prediction_input(gw)``, aligns its
    columns to the features the model was fitted on (when the model records
    them), runs the model, and writes the result to
    ``PREDICTIONS_PATH / f"gw{gw}_predictions.csv"``.

    :param gw: Gameweek number to predict for.
    :return: DataFrame with integer columns ``id`` and ``xP`` (predictions
        rounded to the nearest whole number), sorted by ``id``.
    :raises ValueError: If the model has not been loaded, or if the input
        data lacks features the model was trained on.
    """
    if model is None:
        raise ValueError("Model is not loaded. Please load the model before making predictions.")

    input_data = get_prediction_input(gw)
    ids = input_data['id']
    features = input_data.drop(columns=['id', 'team', "name"], errors="ignore")

    # Estimators fitted on a DataFrame record their training columns in
    # feature_names_in_; scikit-learn-style models reject inputs whose columns
    # differ in name or order, so select exactly those columns, in order.
    model_feature_names = getattr(model, "feature_names_in_", None)
    if model_feature_names is not None:
        missing_features = set(model_feature_names) - set(features.columns)
        if missing_features:
            logger.warning(f"Missing features in input data: {missing_features}")
            raise ValueError(
                f"Input data for GW {gw} is missing model features: {sorted(missing_features)}"
            )
        features = features[list(model_feature_names)]
    logger.info(f"Data Preparation for GW {gw} prediction completed.")

    prediction = model.predict(features)
    result = pd.DataFrame({'id': ids, 'xP': prediction})
    # Cast before sorting so ids order numerically even if they arrived as strings.
    result["id"] = result["id"].astype(int)
    result["xP"] = result["xP"].round(0).astype(int)
    result = result.sort_values(by="id")

    output_path = PREDICTIONS_PATH / f"gw{gw}_predictions.csv"
    # to_csv does not create missing directories.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    result.to_csv(output_path, index=False)
    logger.info(f"Predictions for GW {gw} saved to {output_path}")
    return result

def update_model_data():
    """
    Refresh the data used for predictions by calling ``update_data``.

    Any exception raised by the update is logged with its full traceback and
    then swallowed, so this function always returns ``None``.
    """
    try:
        update()
    except Exception:
        # Best-effort refresh: keep the swallow-and-log behaviour, but
        # logger.exception records the traceback instead of only str(e).
        logger.exception("Error occurred while updating model data")

def model_retarin():
    """
    Placeholder for retraining the model; currently does nothing and returns None.

    NOTE(review): the name looks like a misspelling of "model_retrain"; it is
    kept unchanged because callers may already reference it.
    """