import math
from pathlib import Path
import pandas as pd
import joblib
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error,mean_squared_error,root_mean_squared_error
from scripts.data_preprocess import preprocess_data_new as preprocess_data
from scripts.prepare_data import prepare_data
import logging

# Directory that retrain_model() writes the fitted model file into.
# The path is relative, so it is resolved against the current working directory.
MODEL_DIR = Path('models/')

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

def r_mse(pred, y):
    """Return the root mean squared error between ``pred`` and ``y``, rounded to 6 decimals.

    Both operands must support elementwise subtraction and squaring, and the
    result must expose ``.mean()`` (e.g. NumPy arrays or pandas Series).
    """
    squared_error = (pred - y) ** 2
    mean_squared = squared_error.mean()
    return round(math.sqrt(mean_squared), 6)

def mae(pred, y):
    """Return the mean absolute error between ``pred`` and ``y``, rounded to 6 decimals.

    Both operands must support elementwise subtraction and ``abs()``, and the
    result must expose ``.mean()`` (e.g. NumPy arrays or pandas Series).
    """
    absolute_error = abs(pred - y)
    return round(absolute_error.mean(), 6)

def retrain_model(validation_season: str = '2526', validation_gw: int = 1, validation_len: int = 1) -> None:
    """
    Retrain the model with the latest data and save it.

    Steps, in order:
      1. ``prepare_data()`` is called to refresh the underlying data.
      2. ``preprocess_data`` produces the train / validation split.
      3. A ``RandomForestRegressor`` is fitted on the training split.
      4. MAE, MSE and RMSE on the validation split are logged.
      5. The fitted model is written to ``MODEL_DIR / 'fpl_model.pkl'``
         (the directory is created if it does not exist yet).

    Args:
        validation_season: Forwarded unchanged to ``preprocess_data``.
            Presumably a season code such as '2526' -- confirm in
            ``scripts.data_preprocess``.
        validation_gw: Forwarded unchanged to ``preprocess_data``.
            Presumably the first gameweek of the validation window -- confirm.
        validation_len: Forwarded unchanged to ``preprocess_data``.
            Presumably the number of gameweeks held out -- confirm.
    """
    logger.info("Preparing Data for Retrainig the Model")
    prepare_data()

    # Preprocess the data
    logger.info("Preprocessing Data for Retrainig the Model")
    X_train, y_train, X_test, y_test = preprocess_data(validation_season, validation_gw, validation_len)

    # Initialize and train the model
    logger.info("Initializing and Fitting the Model")
    # Previous configuration, kept for reference:
    # RandomForestRegressor(n_jobs=-1, max_depth=9, oob_score=True,min_samples_leaf=5,max_features=0.5,n_estimators=250)
    model = RandomForestRegressor(
        n_jobs=-1,
        max_depth=5,
        oob_score=True,
        min_samples_leaf=10,
        max_features=0.5,
        n_estimators=350,
    )
    # Flatten the target to 1-D before fitting.
    model.fit(X_train, y_train.values.ravel())

    # Evaluate the model on the held-out split.
    logger.info("Evaluationg The Model ")
    y_pred = model.predict(X_test)
    # Distinct local names so the module-level mae() helper is not shadowed.
    mae_score = mean_absolute_error(y_test, y_pred)
    mse_score = mean_squared_error(y_test, y_pred)
    rmse_score = root_mean_squared_error(y_test, y_pred)
    logger.info(f"Model Evaluation Complete => MAE: {mae_score}, MSE: {mse_score}, RMSE: {rmse_score}")

    # Save the trained model. Create the target directory first so a fresh
    # checkout does not fail with FileNotFoundError after the training run.
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, MODEL_DIR / 'fpl_model.pkl')
    logger.info("Saved the new Trained Model")


# Script entry point: retrain with the default validation arguments, then
# print a confirmation to stdout.
if __name__ == "__main__":
    retrain_model()
    print("Model retraining complete.")