"""
Store Sales V8 - ULTIMATE
=========================
Target: Battere top LB 0.377
Current best: V7 = 0.383

Miglioramenti:
1. Più lag features (7, 14, 21, 28, 35, 42, 49, 56, 63)
2. Yearly seasonality (lag 364, 365, 366)
3. Promo features avanzate
4. Oil momentum features
5. Store/family interaction features
6. Optimized hyperparameters
7. Multiple models ensemble
"""

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import TimeSeriesSplit
import warnings
warnings.filterwarnings('ignore')

print("="*70)
print("   🏆 STORE SALES V8 - ULTIMATE 🏆")
print("   Target: Beat 0.377")
print("="*70)

# =============================================================================
# LOAD DATA
# =============================================================================
print("\n[1/8] Loading data...")
PATH = "/kaggle/input/store-sales-time-series-forecasting"

def _read_csv(name, *, dated=True):
    """Read one competition CSV; parse its 'date' column when present."""
    kwargs = {'parse_dates': ['date']} if dated else {}
    return pd.read_csv(f"{PATH}/{name}.csv", **kwargs)

train = _read_csv('train')
test = _read_csv('test')
stores = _read_csv('stores', dated=False)   # stores.csv has no date column
oil = _read_csv('oil')
holidays = _read_csv('holidays_events')
transactions = _read_csv('transactions')

print(f"  Train: {train.shape}, Test: {test.shape}")

# =============================================================================
# ADVANCED OIL FEATURES
# =============================================================================
print("\n[2/8] Advanced oil features...")

# Put the oil series on a continuous daily calendar spanning train + test,
# then fill the gaps (weekends/holidays have no quote).
full_range = pd.date_range(train['date'].min(), test['date'].max())
oil = (oil.rename(columns={'dcoilwtico': 'oil_price'})
          .set_index('date')
          .reindex(full_range)
          .reset_index())
oil.columns = ['date', 'oil_price']

price = oil['oil_price'].interpolate(method='linear').ffill().bfill()
oil['oil_price'] = price

# Oil momentum: moving averages, deviation from trend, volatility, daily change
for window in (7, 14, 28):
    oil[f'oil_ma{window}'] = price.rolling(window, min_periods=1).mean()
oil['oil_momentum'] = price - oil['oil_ma7']
oil['oil_volatility'] = price.rolling(7, min_periods=1).std().fillna(0)
oil['oil_change'] = price.pct_change().fillna(0)

# =============================================================================
# HOLIDAY FEATURES
# =============================================================================
print("\n[3/8] Holiday features...")

# National holidays only; transferred holidays are excluded because their
# observed celebration falls on a different calendar date.
nat_holidays = holidays[(holidays['locale'] == 'National') &
                        (holidays['transferred'] == False)].copy()

# One binary flag per major holiday, matched on the Spanish description text
# (e.g. 'Navidad' = Christmas, 'Viernes Santo' = Good Friday).
nat_holidays['is_christmas'] = nat_holidays['description'].str.contains('Navidad', case=False, na=False).astype(int)
nat_holidays['is_newyear'] = nat_holidays['description'].str.contains('Año|Primer', case=False, na=False).astype(int)
nat_holidays['is_carnival'] = nat_holidays['description'].str.contains('Carnaval', case=False, na=False).astype(int)
nat_holidays['is_goodfriday'] = nat_holidays['description'].str.contains('Viernes Santo', case=False, na=False).astype(int)
nat_holidays['is_labor'] = nat_holidays['description'].str.contains('Trabajo', case=False, na=False).astype(int)
nat_holidays['is_independence'] = nat_holidays['description'].str.contains('Independencia|Grito', case=False, na=False).astype(int)

# Collapse to one row per date; 'max' acts as a logical OR when a date has
# several holiday records.
holiday_features = nat_holidays.groupby('date').agg({
    'is_christmas': 'max', 'is_newyear': 'max', 'is_carnival': 'max',
    'is_goodfriday': 'max', 'is_labor': 'max', 'is_independence': 'max'
}).reset_index()
holiday_features['is_holiday'] = 1

# Days to/from holiday, computed on a continuous calendar covering train + test
all_dates = pd.DataFrame({'date': pd.date_range(train['date'].min(), test['date'].max())})
all_dates = all_dates.merge(holiday_features[['date', 'is_holiday']], on='date', how='left')
all_dates['is_holiday'] = all_dates['is_holiday'].fillna(0)

# Days since last holiday: cumsum of the flag is constant between holidays and
# increments on each one, so cumcount within each cumsum group restarts at 0
# on every holiday row.
all_dates['days_since_holiday'] = all_dates.groupby(all_dates['is_holiday'].cumsum()).cumcount()
# Days until next holiday: same trick applied to the reversed frame
all_dates['days_until_holiday'] = all_dates.iloc[::-1].groupby(all_dates.iloc[::-1]['is_holiday'].cumsum()).cumcount().iloc[::-1].values

# Bring in the per-holiday-type flags; is_holiday already exists, so the
# merged copy is suffixed '_dup' and dropped below.
all_dates = all_dates.merge(holiday_features, on='date', how='left', suffixes=('', '_dup'))
# Drop duplicate is_holiday if exists
if 'is_holiday_dup' in all_dates.columns:
    all_dates = all_dates.drop(columns=['is_holiday_dup'])
# Non-holiday dates got NaN from the left merge; zero-fill numeric columns
for col in all_dates.columns:
    if col != 'date' and all_dates[col].dtype in ['float64', 'int64', 'Int64']:
        all_dates[col] = all_dates[col].fillna(0)

# =============================================================================
# COMBINE DATA
# =============================================================================
print("\n[4/8] Combining data...")

# Stack train and test so every feature is built identically for both;
# is_train distinguishes them and test sales are unknown (NaN).
train['is_train'] = 1
test['is_train'] = 0
test['sales'] = np.nan

df = pd.concat([train, test], ignore_index=True)

# Attach the auxiliary tables on their natural keys
for frame, keys in ((stores, 'store_nbr'),
                    (oil, 'date'),
                    (all_dates, 'date'),
                    (transactions, ['date', 'store_nbr'])):
    df = df.merge(frame, on=keys, how='left')

# Missing transactions: per-store median first, then 0 for stores with no data
store_median_tx = df.groupby('store_nbr')['transactions'].transform('median')
df['transactions'] = df['transactions'].fillna(store_median_tx)
df['transactions'] = df['transactions'].fillna(0)

# =============================================================================
# FEATURE ENGINEERING
# =============================================================================
print("\n[5/8] Feature engineering...")

# Date features: raw calendar components
df['year'] = df['date'].dt.year
df['month'] = df['date'].dt.month
df['day'] = df['date'].dt.day
df['dayofweek'] = df['date'].dt.dayofweek
df['dayofyear'] = df['date'].dt.dayofyear
df['weekofyear'] = df['date'].dt.isocalendar().week.astype(int)  # isocalendar() yields UInt32; cast to plain int
df['quarter'] = df['date'].dt.quarter

# Advanced date features
df['is_weekend'] = (df['dayofweek'] >= 5).astype(int)  # Saturday=5, Sunday=6
df['is_month_start'] = (df['day'] <= 5).astype(int)
df['is_month_end'] = (df['day'] >= 26).astype(int)
# Mid-month and end-of-month paydays — presumably the two local wage dates;
# confirm against domain knowledge
df['is_payday'] = ((df['day'] == 15) | (df['day'] >= 28)).astype(int)
df['week_of_month'] = (df['day'] - 1) // 7 + 1

# Cyclic encoding: sin/cos pairs so period boundaries (Dec->Jan, Sun->Mon)
# end up adjacent in feature space
for col, period in [('month', 12), ('dayofweek', 7), ('dayofyear', 365), ('day', 31)]:
    df[f'{col}_sin'] = np.sin(2 * np.pi * df[col] / period)
    df[f'{col}_cos'] = np.cos(2 * np.pi * df[col] / period)

# Encode categoricals as integer codes
for col in ['family', 'city', 'state', 'type']:
    df[f'{col}_encoded'] = df[col].astype('category').cat.codes

# Store features
# NOTE(review): 'store_family' is never included in the model feature list —
# confirm before removing
df['store_family'] = df['store_nbr'].astype(str) + '_' + df['family']

# Promo features
# Total items on promotion in the store that day (summed across families)
df['promo_intensity'] = df.groupby(['date', 'store_nbr'])['onpromotion'].transform('sum')
# Average promotion level for the family that day across all stores
df['family_promo_rate'] = df.groupby(['date', 'family'])['onpromotion'].transform('mean')

print(f"  Base features created")

# =============================================================================
# LAG FEATURES (EXTENSIVE)
# =============================================================================
print("\n[6/8] Creating extensive lag features...")

# Sort so shift/rolling within each (store, family) group moves along time
df = df.sort_values(['store_nbr', 'family', 'date']).reset_index(drop=True)

# Weekly lags (7, 14, 21, 28, 35, 42, 49, 56): multiples of 7 keep the
# day-of-week aligned with the target row
for lag in [7, 14, 21, 28, 35, 42, 49, 56]:
    df[f'sales_lag_{lag}'] = df.groupby(['store_nbr', 'family'])['sales'].shift(lag)

# Yearly lags (364 = 52 weeks preserves weekday alignment; 365/366 cover
# calendar-year and leap-year offsets)
for lag in [364, 365, 366]:
    df[f'sales_lag_{lag}'] = df.groupby(['store_nbr', 'family'])['sales'].shift(lag)

# Rolling statistics over trailing windows; shift(1) excludes the current day
# so no feature ever sees the value being predicted. std() is the pandas
# default sample std (ddof=1) and is NaN for a single observation, hence
# the fillna(0).
for window in [7, 14, 28, 56]:
    df[f'sales_roll_mean_{window}'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
        lambda x: x.shift(1).rolling(window, min_periods=1).mean())
    df[f'sales_roll_std_{window}'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
        lambda x: x.shift(1).rolling(window, min_periods=1).std().fillna(0))
    df[f'sales_roll_max_{window}'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
        lambda x: x.shift(1).rolling(window, min_periods=1).max())
    df[f'sales_roll_min_{window}'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
        lambda x: x.shift(1).rolling(window, min_periods=1).min())

# EWM: exponentially weighted mean of past sales (pandas uses alpha = 2/(span+1))
for span in [7, 14, 28]:
    df[f'sales_ewm_{span}'] = df.groupby(['store_nbr', 'family'])['sales'].transform(
        lambda x: x.shift(1).ewm(span=span, min_periods=1).mean())

# Day of week mean (same weekday historical average)
# NOTE(review): shift(1) here moves one ROW within the (store, family,
# dayofweek) group, i.e. one week back — intended, but worth confirming
df['dow_family_mean'] = df.groupby(['store_nbr', 'family', 'dayofweek'])['sales'].transform(
    lambda x: x.shift(1).expanding().mean())

print(f"  Lag features created")

# =============================================================================
# PREPARE DATA
# =============================================================================
print("\n[7/8] Preparing training data...")

# All hand-picked base features (lag columns are appended below)
base_features = [
    'store_nbr', 'onpromotion', 'cluster', 'transactions',
    'year', 'month', 'day', 'dayofweek', 'dayofyear', 'weekofyear', 'quarter',
    'is_weekend', 'is_month_start', 'is_month_end', 'is_payday', 'week_of_month',
    'month_sin', 'month_cos', 'dayofweek_sin', 'dayofweek_cos',
    'dayofyear_sin', 'dayofyear_cos', 'day_sin', 'day_cos',
    'oil_price', 'oil_ma7', 'oil_ma14', 'oil_ma28', 'oil_momentum', 'oil_volatility', 'oil_change',
    'is_holiday', 'is_christmas', 'is_newyear', 'is_carnival', 'is_goodfriday', 'is_labor', 'is_independence',
    'days_since_holiday', 'days_until_holiday',
    'family_encoded', 'city_encoded', 'state_encoded', 'type_encoded',
    'promo_intensity', 'family_promo_rate'
]

# Pick up every generated lag/rolling/EWM/day-of-week column by name pattern
lag_features = [c for c in df.columns if 'lag_' in c or 'roll_' in c or 'ewm_' in c or 'dow_family' in c]
all_features = base_features + lag_features
all_features = [f for f in all_features if f in df.columns]  # drop any that were never created

# Training data (from 2016, after lag warmup). Lags were computed on the full
# frame, so even the yearly lags of early-2016 rows reach real 2015 sales.
train_df = df[(df['is_train'] == 1) & (df['date'] >= '2016-01-01')].copy()
train_df = train_df.dropna(subset=['sales'])

# Fill lag NaN with 0 for training
for col in lag_features:
    if col in train_df.columns:
        train_df[col] = train_df[col].fillna(0)

X_train = train_df[all_features]
y_train = train_df['sales']
y_train_log = np.log1p(y_train)  # train in log space: RMSE there equals RMSLE

# Realistic validation: hold out the last 16 days, same horizon as the test set
val_start = train_df['date'].max() - pd.Timedelta(days=15)
train_mask = train_df['date'] < val_start
val_mask = train_df['date'] >= val_start

# .values forces positional boolean indexing (train_df keeps a filtered index)
X_tr, y_tr = X_train[train_mask.values], y_train_log[train_mask.values]
X_val, y_val = X_train[val_mask.values], y_train_log[val_mask.values]

print(f"  Train: {len(X_tr)}, Val: {len(X_val)}")
print(f"  Features: {len(all_features)}")

# =============================================================================
# TRAIN OPTIMIZED MODEL
# =============================================================================
print("\n[8/8] Training optimized LightGBM...")

# Low learning rate + many rounds + early stopping; large num_leaves with
# feature/row subsampling and L1/L2 regularization to limit overfitting
params = {
    'objective': 'regression',
    'metric': 'rmse',  # RMSE on log1p targets == RMSLE on raw sales
    'boosting_type': 'gbdt',
    'learning_rate': 0.02,
    'num_leaves': 255,
    'max_depth': 12,
    'min_child_samples': 20,
    'feature_fraction': 0.7,
    'bagging_fraction': 0.7,
    'bagging_freq': 1,
    'lambda_l1': 0.1,
    'lambda_l2': 1.0,
    'min_gain_to_split': 0.01,
    'verbose': -1,
    'n_jobs': -1,
    'seed': 42
}

train_data = lgb.Dataset(X_tr, label=y_tr)
val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

# Up to 5000 rounds; stop once validation RMSE stalls for 200 rounds
model = lgb.train(
    params, train_data,
    num_boost_round=5000,
    valid_sets=[val_data],
    callbacks=[lgb.early_stopping(200), lgb.log_evaluation(200)]
)

# Validation score: back-transform to sales, clip negatives, recompute RMSLE
val_pred = np.expm1(model.predict(X_val))
val_pred = np.maximum(val_pred, 0)
val_actual = np.expm1(y_val)

rmsle = np.sqrt(np.mean((np.log1p(val_pred) - np.log1p(val_actual))**2))
print(f"\n  ⭐ Validation RMSLE: {rmsle:.5f}")
print(f"  Target: 0.377, Gap: {rmsle - 0.377:.5f}")

# =============================================================================
# RECURSIVE PREDICTION
# =============================================================================
# Day-by-day forecast: each test day's lag/rolling/EWM features are rebuilt
# from real history plus the predictions already made for earlier test days.
print("\n[9/9] Recursive prediction for test...")

test_df = df[df['is_train'] == 0].copy()
last_train_date = train['date'].max()

# ~400 days of history covers the longest rolling window (56); the yearly
# lags (364-366) keep their values precomputed on the full frame earlier.
recent_train = df[(df['is_train'] == 1) &
                  (df['date'] > last_train_date - pd.Timedelta(days=400))].copy()

pred_df = pd.concat([recent_train, test_df], ignore_index=True)
pred_df = pred_df.sort_values(['store_nbr', 'family', 'date']).reset_index(drop=True)

test_dates = sorted(test_df['date'].unique())
print(f"  Predicting {len(test_dates)} days...")

# Rows are sorted by (store, family, date), so within each series the row
# labels are already in chronological order. Resolving them once avoids
# re-filtering the whole frame for every (row, day) pair — same results,
# far fewer full scans.
series_rows = {key: np.asarray(rows) for key, rows in
               pred_df.groupby(['store_nbr', 'family']).groups.items()}

all_predictions = []

for i, pred_date in enumerate(test_dates):
    date_mask = pred_df['date'] == pred_date

    # Update lag features from predictions made so far
    for idx in pred_df[date_mask].index:
        store = pred_df.loc[idx, 'store_nbr']
        family = pred_df.loc[idx, 'family']

        rows = series_rows[(store, family)]
        # All sales strictly before pred_date for this series, oldest first
        hist = pred_df.loc[rows[rows < idx], 'sales'].values

        # Weekly lags
        for lag in [7, 14, 21, 28, 35, 42, 49, 56]:
            col = f'sales_lag_{lag}'
            if col in pred_df.columns and len(hist) >= lag:
                pred_df.loc[idx, col] = hist[-lag]

        # Rolling features. ddof=1 matches the sample std that pandas
        # rolling().std() produced when the training features were built
        # (np.std defaults to ddof=0, which skewed the feature at serve time).
        for window in [7, 14, 28, 56]:
            if len(hist) >= window:
                recent = hist[-window:]
                pred_df.loc[idx, f'sales_roll_mean_{window}'] = np.mean(recent)
                pred_df.loc[idx, f'sales_roll_std_{window}'] = np.std(recent, ddof=1)
                pred_df.loc[idx, f'sales_roll_max_{window}'] = np.max(recent)
                pred_df.loc[idx, f'sales_roll_min_{window}'] = np.min(recent)

        # EWM with pandas' own weighting (alpha = 2/(span+1), weights
        # (1-alpha)^k with the newest observation heaviest), truncated to the
        # last `span` points. The previous exp(linspace) weights did not match
        # the ewm() features the model was trained on.
        for span in [7, 14, 28]:
            if len(hist) >= span:
                alpha = 2.0 / (span + 1)
                weights = (1.0 - alpha) ** np.arange(span - 1, -1, -1)
                pred_df.loc[idx, f'sales_ewm_{span}'] = np.average(hist[-span:], weights=weights)

    # Predict this day for every (store, family) at once; invert the log1p
    # transform and clip negatives
    X_pred = pred_df.loc[date_mask, all_features].fillna(0)
    preds = np.expm1(model.predict(X_pred))
    preds = np.maximum(preds, 0)

    # Write predictions back so later days can use them as lags
    pred_df.loc[date_mask, 'sales'] = preds

    for pid, psale in zip(pred_df.loc[date_mask, 'id'].values, preds):
        all_predictions.append({'id': int(pid), 'sales': psale})

    if (i + 1) % 4 == 0:
        print(f"    Day {i + 1}/{len(test_dates)} done")

# Create submission
submission = pd.DataFrame(all_predictions).sort_values('id')
submission.to_csv('submission.csv', index=False)

print(f"\n  Submission: {len(submission)} rows")
print(f"  Sales range: {submission['sales'].min():.2f} - {submission['sales'].max():.2f}")
print(f"  Sales mean: {submission['sales'].mean():.2f}")

# Feature importance report: top 15 split-count importances from the model
print("\n[Top 15 Features]")
importance_table = (
    pd.DataFrame({'feature': all_features,
                  'importance': model.feature_importance()})
    .sort_values('importance', ascending=False)
    .head(15)
)
print(importance_table.to_string(index=False))

banner = "=" * 70
print("\n" + banner)
print(f"   🏆 V8 ULTIMATE - Val RMSLE: {rmsle:.5f}")
print(f"   Target: 0.377 | Gap: {rmsle - 0.377:+.5f}")
print(banner)
