import joblib
import logging
from pathlib import Path
import pandas as pd
from scripts.get_prediction_input import get_prediction_input
from scripts.update_data import update_data as update
from scripts.retrain_model import retrain_model
from scripts.model_eval import evaluate_model_predictions
from app.constants import MODELS_FEATURES

MODEL_PATH = Path('models/fpl_model.pkl')
MODELS_DIR = Path('models')
PREDICTIONS_PATH = Path('data/predictions/2526')
PLAYERS_GW_PATH = Path('data/raw/2025-26/gws')
logger = logging.getLogger(__name__)
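
# Lazily-loaded module-level caches: the main model and any additional models are
# read from disk once and reused across calls.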
models = None
model = None


def load_model():
    """
    Load the pre-trained model from the specified path.
    """
    global model
    if model is None:
        try:
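            # joblib.load restores the serialized (pickled) model object from disk.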
            model = joblib.load(MODEL_PATH)
        except FileNotFoundError:
            logger.error(f"Model file not found at {MODEL_PATH}")
            model = None

    return model


def load_other_models():
    """
    Load any additional pre-trained models found in the models directory.
    """
    global models
    if models is None:
        try:
            models = {}
            # Discover every .pkl file other than the main model and load it keyed by file stem.
            model_names = [
                f.stem for f in MODELS_DIR.iterdir()
                if f.is_file() and f.suffix == '.pkl' and f.name != 'fpl_model.pkl'
            ]
            logger.debug(f"Additional models found: {model_names}")
            for name in model_names:
                model_path = MODELS_DIR / f"{name}.pkl"
                models[name] = joblib.load(model_path)
        except FileNotFoundError:
            logger.error(f"Models directory or model file not found under {MODELS_DIR}")
            models = None

    return models


def predict(gw: int) -> pd.DataFrame:
    """
    Predict expected points for every player in the given gameweek using the main model.

    :param gw: Gameweek number to predict for.
    :return: DataFrame of player ``id`` and rounded expected points ``xP``.
    """
    if model is None:
        raise ValueError("Model is not loaded. Please load the model before making predictions.")
    
    logger.info(f"Fetching prediction input data for GW {gw} started.")
    input_data = get_prediction_input(gw)
    logger.info(f"Fetching prediction input data for GW {gw} completed.")

    ids = input_data['id']
    input_data = input_data.drop(columns=['id', 'team', 'name'], errors='ignore')

    # Warn if the input is missing any feature the model was trained on.
    model_feature_names = model.feature_names_in_
    missing_features = set(model_feature_names) - set(input_data.columns)
    if missing_features:
        logger.warning(f"Missing features in input data: {missing_features}")
    logger.info(f"Data preparation for GW {gw} prediction completed.")

    prediction = model.predict(input_data)
    logger.info(f"Main model predictions for GW {gw} completed.")
    #_ = make_other_model_predictions(gw,input_data,ids)

    result = pd.DataFrame({
        'id': ids,
        'xP': prediction
    }).sort_values(by="id")
    result["xP"] = result["xP"].round(0).astype(int)
    result["id"] = result["id"].astype(int)

    # Make sure the predictions directory exists before writing the CSV.
    PREDICTIONS_PATH.mkdir(parents=True, exist_ok=True)
    result.to_csv(PREDICTIONS_PATH / f"gw{gw}_predictions.csv", index=False)
    logger.info(f"Predictions for GW {gw} saved to {PREDICTIONS_PATH / f'gw{gw}_predictions.csv'}")
    return result


def make_other_model_predictions(gw: int, features: pd.DataFrame, ids: pd.Series) -> None:
    """
    Run predictions with every additional model and save one CSV per model.
    """
    other_models = load_other_models()
    if other_models is None:
        return
    try:
        for model_name, other_model in other_models.items():
            if model_name not in MODELS_FEATURES:
                logger.warning(f"No feature list found for model {model_name}. Skipping prediction.")
                continue

            model_features = MODELS_FEATURES[model_name]
            missing_features = set(model_features) - set(features.columns)
            if missing_features:
                logger.warning(f"Missing features for model {model_name}: {missing_features}")

            model_input = features[model_features].copy() if model_features else features.copy()
            prediction = other_model.predict(model_input)

            # Build id/xP pairs before sorting so predictions stay aligned with their players.
            results = pd.DataFrame({'id': ids.astype(int), 'xP': prediction}).sort_values(by="id")
            results["xP"] = results["xP"].round(0).astype(int)
            logger.info(f"Predictions for GW {gw} using model {model_name} completed.")

            output_dir = PREDICTIONS_PATH / model_name
            output_dir.mkdir(parents=True, exist_ok=True)
            results.to_csv(output_dir / f"gw{gw}_predictions.csv", index=False)
        logger.info(f"Other model predictions for GW {gw} saved.")
    except Exception as e:
        logger.error(f"Error during other model predictions: {e}")


def update_model_data():
    try:
        update()
    except Exception as e:
        logger.error(f"Error occurred while updating data: {e}")
    finally:
        # Evaluate stored predictions against actual results even if the update failed.
        logger.info("Evaluation for model performance started.")
        evaluate_model_predictions()
        logger.info("Evaluation for model performance completed and saved.")


def get_player_data(gw):
    """
    Return each player's actual total points for the given gameweek as a list of records.
    """
    players_df = pd.read_csv(PLAYERS_GW_PATH / f"gw{gw}.csv")[["element", "total_points"]]
    players_df = players_df.rename(columns={'element': 'id'})
    players_df['id'] = players_df['id'].astype('Int64')
    players_df['total_points'] = players_df['total_points'].astype('Int64')
    players_df = players_df.sort_values(by=['id'])
    return players_df.to_dict(orient='records')


def model_retrain(gw: int):
    try:
        retrain_model(validation_season="2526", validation_gw=gw // 3, validation_len=min(5, max(1, gw - 5)))
    except Exception as e:
        logger.error(f"Error occurred while retraining the model: {e}")


if __name__ == "__main__":
    # Basic logging setup so the info-level messages above are visible in manual runs.
    logging.basicConfig(level=logging.INFO)
    load_model()
    # Example usage
    gw = 4
    predictions = predict(gw)
    print(predictions.head())
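
    # Optional sketch: generate predictions from the additional models as well.
    # Assumes the extra .pkl models and their entries in MODELS_FEATURES exist.
    # input_data = get_prediction_input(gw)
    # features = input_data.drop(columns=['id', 'team', 'name'], errors='ignore')
    # make_other_model_predictions(gw, features, input_data['id'])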