"""
Store Sales Time Series Forecasting - ADVANCED V2
Author: vincenzorubino
Target: TOP 10% Leaderboard

Techniques:
- Zero forecasting for never-sold products
- Store opening date handling
- Oil price interpolation
- Holiday/Event processing (National/Regional/Local)
- Earthquake impact features
- Wage day features (15th and month end)
- PACF-based lag features (16, 30, 45, 365)
- SMA & EWM features
- LightGBM with Tweedie objective
- Ensemble with XGBoost and CatBoost
"""

import numpy as np
import pandas as pd
import lightgbm as lgb
import xgboost as xgb
from catboost import CatBoostRegressor
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import mean_squared_log_error
import warnings
import gc
warnings.filterwarnings('ignore')

print("="*70)
print("   Store Sales - ADVANCED V2 - Target: TOP 10%")
print("="*70)

# ============================================================================
# 1. LOAD DATA
# ============================================================================
print("\n[1/10] Loading data...")

train = pd.read_csv('/kaggle/input/store-sales-time-series-forecasting/train.csv',
                    parse_dates=['date'], dtype={'store_nbr': 'int8', 'onpromotion': 'float32'})
test = pd.read_csv('/kaggle/input/store-sales-time-series-forecasting/test.csv',
                   parse_dates=['date'], dtype={'store_nbr': 'int8', 'onpromotion': 'float32'})
stores = pd.read_csv('/kaggle/input/store-sales-time-series-forecasting/stores.csv')
oil = pd.read_csv('/kaggle/input/store-sales-time-series-forecasting/oil.csv', parse_dates=['date'])
holidays = pd.read_csv('/kaggle/input/store-sales-time-series-forecasting/holidays_events.csv', parse_dates=['date'])
transactions = pd.read_csv('/kaggle/input/store-sales-time-series-forecasting/transactions.csv', parse_dates=['date'])

train['sales'] = train['sales'].astype('float32')
print(f"  Train: {train.shape}, Test: {test.shape}")

# ============================================================================
# 2. STORE OPENING DATES - Remove data before stores opened
# ============================================================================
print("\n[2/10] Cleaning store opening data...")

store_open_dates = {
    52: '2017-04-20', 22: '2015-10-09', 42: '2015-08-21',
    21: '2015-07-24', 29: '2015-03-20', 20: '2015-02-13',
    53: '2014-05-29', 36: '2013-05-09'
}

original_len = len(train)
for store, date in store_open_dates.items():
    train = train[~((train.store_nbr == store) & (train.date < date))]
print(f"  Removed {original_len - len(train)} rows before store openings")

# ============================================================================
# 3. ZERO FORECASTING - Products never sold
# ============================================================================
print("\n[3/10] Identifying zero-sale products...")

zero_sales = train.groupby(['store_nbr', 'family']).sales.sum().reset_index()
zero_sales = zero_sales[zero_sales.sales == 0][['store_nbr', 'family']]

# Remove from training
train = train.merge(zero_sales, on=['store_nbr', 'family'], how='left', indicator=True)
train = train[train._merge == 'left_only'].drop('_merge', axis=1)

# Create zero predictions for these
zero_prediction = []
for _, row in zero_sales.iterrows():
    zero_prediction.append(pd.DataFrame({
        'date': pd.date_range('2017-08-16', '2017-08-31'),
        'store_nbr': row['store_nbr'],
        'family': row['family'],
        'sales': 0.0
    }))
if zero_prediction:
    zero_prediction = pd.concat(zero_prediction)
    print(f"  Found {len(zero_sales)} zero-sale store/family combinations")
else:
    zero_prediction = pd.DataFrame()

# ============================================================================
# 4. OIL PRICE INTERPOLATION
# ============================================================================
print("\n[4/10] Processing oil prices...")

oil = oil.set_index('date').resample('D').mean().reset_index()
oil['dcoilwtico'] = oil['dcoilwtico'].interpolate(method='linear')
oil['dcoilwtico'] = oil['dcoilwtico'].fillna(method='bfill').fillna(method='ffill')
oil.columns = ['date', 'oil_price']

# Oil price features
oil['oil_ma7'] = oil['oil_price'].rolling(7).mean()
oil['oil_ma30'] = oil['oil_price'].rolling(30).mean()
oil['oil_price_high'] = (oil['oil_price'] > 70).astype('int8')

# ============================================================================
# 5. HOLIDAYS PROCESSING
# ============================================================================
print("\n[5/10] Processing holidays and events...")

# Handle transferred holidays
transferred = holidays[(holidays.type == 'Holiday') & (holidays.transferred == True)]
transfer_dates = holidays[holidays.type == 'Transfer']

# Clean holidays
holidays = holidays[(holidays.transferred == False) & (holidays.type != 'Transfer')]

# National holidays
national = holidays[holidays.locale == 'National'][['date', 'description']].drop_duplicates()
national.columns = ['date', 'holiday_national']
national['is_national_holiday'] = 1

# Regional holidays
regional = holidays[holidays.locale == 'Regional'][['date', 'locale_name', 'description']].drop_duplicates()
regional.columns = ['date', 'state', 'holiday_regional']

# Local holidays
local = holidays[holidays.locale == 'Local'][['date', 'locale_name', 'description']].drop_duplicates()
local.columns = ['date', 'city', 'holiday_local']

# Events
events = holidays[holidays.type == 'Event'][['date', 'description']].drop_duplicates()
events.columns = ['date', 'event']

# Earthquake
events['is_earthquake'] = events['event'].str.contains('Terremoto', case=False, na=False).astype('int8')
events['is_futbol'] = events['event'].str.contains('futbol|fútbol', case=False, na=False).astype('int8')

# ============================================================================
# 6. MERGE ALL DATA
# ============================================================================
print("\n[6/10] Merging all features...")

# Combine train and test
train['is_train'] = 1
test['is_train'] = 0
test['sales'] = np.nan
df = pd.concat([train, test], ignore_index=True)

# Merge stores
df = df.merge(stores, on='store_nbr', how='left')

# Merge oil
df = df.merge(oil, on='date', how='left')
df['oil_price'] = df['oil_price'].fillna(method='ffill').fillna(method='bfill')
df['oil_ma7'] = df['oil_ma7'].fillna(df['oil_price'])
df['oil_ma30'] = df['oil_ma30'].fillna(df['oil_price'])
df['oil_price_high'] = df['oil_price_high'].fillna(0)

# Merge holidays
df = df.merge(national[['date', 'is_national_holiday']], on='date', how='left')
df = df.merge(regional, on=['date', 'state'], how='left')
df = df.merge(local, on=['date', 'city'], how='left')
df = df.merge(events[['date', 'is_earthquake', 'is_futbol']], on='date', how='left')

# Fill NaN
df['is_national_holiday'] = df['is_national_holiday'].fillna(0).astype('int8')
df['is_earthquake'] = df['is_earthquake'].fillna(0).astype('int8')
df['is_futbol'] = df['is_futbol'].fillna(0).astype('int8')
df['is_regional_holiday'] = df['holiday_regional'].notna().astype('int8')
df['is_local_holiday'] = df['holiday_local'].notna().astype('int8')

# ============================================================================
# 7. TIME FEATURES
# ============================================================================
print("\n[7/10] Creating time features...")

df['year'] = df.date.dt.year.astype('int16')
df['month'] = df.date.dt.month.astype('int8')
df['day'] = df.date.dt.day.astype('int8')
df['dayofweek'] = df.date.dt.dayofweek.astype('int8')
df['dayofyear'] = df.date.dt.dayofyear.astype('int16')
df['weekofyear'] = df.date.dt.isocalendar().week.astype('int8')
df['quarter'] = df.date.dt.quarter.astype('int8')

# Binary features
df['is_weekend'] = (df.dayofweek >= 5).astype('int8')
df['is_month_start'] = df.date.dt.is_month_start.astype('int8')
df['is_month_end'] = df.date.dt.is_month_end.astype('int8')

# Wage day (15th and month end - Ecuador public sector)
df['is_wageday'] = ((df.day == 15) | (df.is_month_end == 1)).astype('int8')

# Work day
df['is_workday'] = ((df.is_weekend == 0) &
                    (df.is_national_holiday == 0) &
                    (df.is_regional_holiday == 0) &
                    (df.is_local_holiday == 0)).astype('int8')

# Season (Ecuador - Southern Hemisphere)
df['season'] = 0  # Summer: Dec-Feb
df.loc[df.month.isin([3,4,5]), 'season'] = 1  # Autumn
df.loc[df.month.isin([6,7,8]), 'season'] = 2  # Winter
df.loc[df.month.isin([9,10,11]), 'season'] = 3  # Spring
df['season'] = df['season'].astype('int8')

# Encode categoricals
df['family_encoded'] = df['family'].astype('category').cat.codes.astype('int8')
df['city_encoded'] = df['city'].astype('category').cat.codes.astype('int8')
df['state_encoded'] = df['state'].astype('category').cat.codes.astype('int8')
df['type_encoded'] = df['type'].astype('category').cat.codes.astype('int8')

# ============================================================================
# 8. LAG & ROLLING FEATURES (Only for training data)
# ============================================================================
print("\n[8/10] Creating lag and rolling features...")

df = df.sort_values(['store_nbr', 'family', 'date']).reset_index(drop=True)

# Lag features based on PACF analysis (must be >= 16 since test is 15 days)
for lag in [16, 20, 30, 45, 60]:
    df[f'sales_lag_{lag}'] = df.groupby(['store_nbr', 'family'])['sales'].shift(lag)

# Yearly lag (very important for seasonality)
df['sales_lag_365'] = df.groupby(['store_nbr', 'family'])['sales'].shift(365)

# Rolling means (SMA)
for window in [7, 14, 30, 60]:
    df[f'sales_sma_{window}_lag16'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
        lambda x: x.shift(16).rolling(window, min_periods=1).mean()
    )

# Rolling std
df['sales_std_30_lag16'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
    lambda x: x.shift(16).rolling(30, min_periods=1).std()
)

# Exponential Moving Average
for alpha in [0.95, 0.8, 0.5]:
    alpha_str = str(alpha).replace('.', '')
    df[f'sales_ewm_{alpha_str}_lag16'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
        lambda x: x.shift(16).ewm(alpha=alpha, min_periods=1).mean()
    )

# Promotion rolling
df['promo_rolling_7'] = df.groupby(['store_nbr', 'family'])['onpromotion'].transform(
    lambda x: x.rolling(7, min_periods=1).mean()
)

# Transactions merge and features (only for training period)
trans_agg = transactions.groupby(['store_nbr', 'date']).transactions.sum().reset_index()
df = df.merge(trans_agg, on=['store_nbr', 'date'], how='left')

# For test period, use average transactions
store_avg_trans = transactions.groupby('store_nbr').transactions.mean()
df['transactions'] = df['transactions'].fillna(df['store_nbr'].map(store_avg_trans))

print(f"  Total features: {len(df.columns)}")

# ============================================================================
# 9. PREPARE FOR TRAINING
# ============================================================================
print("\n[9/10] Preparing training data...")

# Split back
train_df = df[df.is_train == 1].copy()
test_df = df[df.is_train == 0].copy()

# Use last 2 years for training (more recent = more relevant)
train_df = train_df[train_df.date >= '2015-08-01']

# Drop rows with NaN in lag features
train_df = train_df.dropna(subset=['sales_lag_16'])

feature_cols = [
    'store_nbr', 'onpromotion', 'cluster',
    'year', 'month', 'day', 'dayofweek', 'dayofyear', 'weekofyear', 'quarter',
    'is_weekend', 'is_month_start', 'is_month_end', 'is_wageday', 'is_workday', 'season',
    'oil_price', 'oil_ma7', 'oil_ma30', 'oil_price_high',
    'is_national_holiday', 'is_regional_holiday', 'is_local_holiday',
    'is_earthquake', 'is_futbol',
    'family_encoded', 'city_encoded', 'state_encoded', 'type_encoded',
    'sales_lag_16', 'sales_lag_20', 'sales_lag_30', 'sales_lag_45', 'sales_lag_60',
    'sales_lag_365',
    'sales_sma_7_lag16', 'sales_sma_14_lag16', 'sales_sma_30_lag16', 'sales_sma_60_lag16',
    'sales_std_30_lag16',
    'sales_ewm_095_lag16', 'sales_ewm_08_lag16', 'sales_ewm_05_lag16',
    'promo_rolling_7', 'transactions'
]

# Handle missing lag_365 (first year doesn't have it)
train_df['sales_lag_365'] = train_df['sales_lag_365'].fillna(train_df['sales_sma_30_lag16'])
test_df['sales_lag_365'] = test_df['sales_lag_365'].fillna(test_df['sales_sma_30_lag16'])

# Fill remaining NaN
for col in feature_cols:
    if col in train_df.columns:
        train_df[col] = train_df[col].fillna(0)
    if col in test_df.columns:
        test_df[col] = test_df[col].fillna(0)

X_train = train_df[feature_cols]
y_train = train_df['sales']

X_test = test_df[feature_cols]

print(f"  Training samples: {len(X_train)}")
print(f"  Test samples: {len(X_test)}")
print(f"  Features: {len(feature_cols)}")

# ============================================================================
# 10. TRAIN ENSEMBLE MODEL
# ============================================================================
print("\n[10/10] Training ensemble model...")

# Transform target (log1p for RMSLE)
y_train_log = np.log1p(y_train)

# Time series CV
tscv = TimeSeriesSplit(n_splits=5)

# ---- LightGBM with Tweedie ----
print("\n  Training LightGBM...")
lgb_params = {
    'objective': 'tweedie',
    'tweedie_variance_power': 1.1,
    'metric': 'rmse',
    'boosting_type': 'gbdt',
    'num_leaves': 255,
    'learning_rate': 0.03,
    'feature_fraction': 0.8,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'min_child_samples': 50,
    'verbose': -1,
    'n_jobs': -1,
    'seed': 42
}

lgb_models = []
lgb_scores = []

for fold, (train_idx, val_idx) in enumerate(tscv.split(X_train)):
    X_tr, X_val = X_train.iloc[train_idx], X_train.iloc[val_idx]
    y_tr, y_val = y_train.iloc[train_idx], y_train.iloc[val_idx]

    train_data = lgb.Dataset(X_tr, label=y_tr)
    val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

    model = lgb.train(
        lgb_params, train_data,
        num_boost_round=2000,
        valid_sets=[val_data],
        callbacks=[lgb.early_stopping(100), lgb.log_evaluation(200)]
    )

    val_pred = model.predict(X_val)
    val_pred = np.maximum(val_pred, 0)

    rmsle = np.sqrt(mean_squared_log_error(y_val + 1, val_pred + 1))
    lgb_scores.append(rmsle)
    lgb_models.append(model)
    print(f"    Fold {fold+1} RMSLE: {rmsle:.5f}")

print(f"  LightGBM Mean CV: {np.mean(lgb_scores):.5f}")

# ---- XGBoost ----
print("\n  Training XGBoost...")
xgb_params = {
    'objective': 'reg:tweedie',
    'tweedie_variance_power': 1.1,
    'eval_metric': 'rmse',
    'max_depth': 8,
    'learning_rate': 0.03,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
    'min_child_weight': 50,
    'seed': 42,
    'n_jobs': -1
}

xgb_models = []
xgb_scores = []

for fold, (train_idx, val_idx) in enumerate(tscv.split(X_train)):
    X_tr, X_val = X_train.iloc[train_idx], X_train.iloc[val_idx]
    y_tr, y_val = y_train.iloc[train_idx], y_train.iloc[val_idx]

    dtrain = xgb.DMatrix(X_tr, label=y_tr)
    dval = xgb.DMatrix(X_val, label=y_val)

    model = xgb.train(
        xgb_params, dtrain,
        num_boost_round=2000,
        evals=[(dval, 'val')],
        early_stopping_rounds=100,
        verbose_eval=200
    )

    val_pred = model.predict(dval)
    val_pred = np.maximum(val_pred, 0)

    rmsle = np.sqrt(mean_squared_log_error(y_val + 1, val_pred + 1))
    xgb_scores.append(rmsle)
    xgb_models.append(model)
    print(f"    Fold {fold+1} RMSLE: {rmsle:.5f}")

print(f"  XGBoost Mean CV: {np.mean(xgb_scores):.5f}")

# ============================================================================
# 11. MAKE PREDICTIONS
# ============================================================================
print("\n[11] Making ensemble predictions...")

# LightGBM predictions
lgb_pred = np.zeros(len(X_test))
for model in lgb_models:
    lgb_pred += model.predict(X_test) / len(lgb_models)

# XGBoost predictions
xgb_pred = np.zeros(len(X_test))
dtest = xgb.DMatrix(X_test)
for model in xgb_models:
    xgb_pred += model.predict(dtest) / len(xgb_models)

# Ensemble (weighted average)
# Weight based on CV scores
lgb_weight = 1 / np.mean(lgb_scores)
xgb_weight = 1 / np.mean(xgb_scores)
total_weight = lgb_weight + xgb_weight

predictions = (lgb_pred * lgb_weight + xgb_pred * xgb_weight) / total_weight
predictions = np.maximum(predictions, 0)

# Create submission
submission = pd.DataFrame({
    'id': test_df['id'].values,
    'sales': predictions
})

# Add zero predictions
if len(zero_prediction) > 0:
    zero_sub = test.merge(zero_prediction[['store_nbr', 'family', 'date', 'sales']],
                          on=['store_nbr', 'family', 'date'], how='inner')
    if len(zero_sub) > 0:
        # Update submission with zero predictions
        for _, row in zero_sub.iterrows():
            submission.loc[submission['id'] == row['id'], 'sales'] = 0

submission.to_csv('submission.csv', index=False)

print(f"\n  Submission saved: {len(submission)} predictions")
print(f"  Sales range: {predictions.min():.2f} - {predictions.max():.2f}")
print(f"  Sales mean: {predictions.mean():.2f}")

# Feature importance
print("\n[Feature Importance - Top 15]")
importance = pd.DataFrame({
    'feature': feature_cols,
    'lgb_importance': lgb_models[0].feature_importance(),
    'xgb_importance': xgb_models[0].get_score(importance_type='gain').get(f, 0) if (f := feature_cols[0]) else 0
})
importance = importance.sort_values('lgb_importance', ascending=False)
print(importance[['feature', 'lgb_importance']].head(15).to_string(index=False))

print("\n" + "="*70)
print(f"   LightGBM CV: {np.mean(lgb_scores):.5f}")
print(f"   XGBoost CV:  {np.mean(xgb_scores):.5f}")
print(f"   Ensemble ready - submission.csv created!")
print("="*70)
