"""
Store Sales V7 - FIXED VERSION
==============================
Problema V1-V6: CV buono ma LB pessimo (0.59)
Causa: Lag features per test set calcolate male

Fix:
1. Recursive forecasting per lag features
2. Niente zero forecasting aggressivo
3. Validazione più realistica (gap tra train e val)
"""

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import TimeSeriesSplit
import warnings
warnings.filterwarnings('ignore')

print("="*70)
print("   🔧 STORE SALES V7 - FIXED VERSION")
print("="*70)

# =============================================================================
# LOAD DATA
# =============================================================================
print("\n[1/7] Loading data...")
PATH = "/kaggle/input/store-sales-time-series-forecasting"

train = pd.read_csv(f"{PATH}/train.csv", parse_dates=['date'])
test = pd.read_csv(f"{PATH}/test.csv", parse_dates=['date'])
stores = pd.read_csv(f"{PATH}/stores.csv")
oil = pd.read_csv(f"{PATH}/oil.csv", parse_dates=['date'])
holidays = pd.read_csv(f"{PATH}/holidays_events.csv", parse_dates=['date'])

print(f"  Train: {train.shape}, Test: {test.shape}")
print(f"  Train dates: {train['date'].min()} to {train['date'].max()}")
print(f"  Test dates: {test['date'].min()} to {test['date'].max()}")

# =============================================================================
# MERGE EXTERNAL DATA
# =============================================================================
print("\n[2/7] Merging external data...")

# Oil: rename and fill the gaps in the WTI price series.
oil = oil.rename(columns={'dcoilwtico': 'oil_price'})
oil['oil_price'] = oil['oil_price'].interpolate(method='linear').ffill().bfill()

# Holidays: keep only national days that were actually celebrated.
#  - transferred == True rows are the *official* date of a moved holiday
#    (people worked that day), so drop them; the matching type == 'Transfer'
#    row carries the celebrated date and is kept.
#  - type == 'Work Day' rows are make-up WORKING days that compensate for a
#    bridge — previously these were wrongly flagged as holidays.
holidays_nat = holidays[(holidays['locale'] == 'National') &
                        (holidays['transferred'] == False) &
                        (holidays['type'] != 'Work Day')][['date']].drop_duplicates()
holidays_nat['is_holiday'] = 1

# Stack train + test so all feature engineering is done once on one frame.
train['is_train'] = 1
test['is_train'] = 0
test['sales'] = np.nan

df = pd.concat([train, test], ignore_index=True)
df = df.merge(stores, on='store_nbr', how='left')
df = df.merge(oil, on='date', how='left')
df = df.merge(holidays_nat, on='date', how='left')

# The oil series has no weekend rows, so the date-merge leaves NaNs: fill.
df['oil_price'] = df['oil_price'].ffill().bfill()
df['is_holiday'] = df['is_holiday'].fillna(0)

# =============================================================================
# FEATURE ENGINEERING (NO LAGS YET)
# =============================================================================
print("\n[3/7] Feature engineering...")

# Calendar features taken straight off the timestamp.
dt = df['date'].dt
df['year'] = dt.year
df['month'] = dt.month
df['day'] = dt.day
df['dayofweek'] = dt.dayofweek
df['dayofyear'] = dt.dayofyear
df['weekofyear'] = dt.isocalendar().week.astype(int)

# Binary calendar flags.
df['is_weekend'] = df['dayofweek'].ge(5).astype(int)
df['is_month_start'] = dt.is_month_start.astype(int)
df['is_month_end'] = dt.is_month_end.astype(int)
# Pay-day style days: the 15th and the last days of the month.
df['is_payday'] = (df['day'].eq(15) | df['day'].ge(28)).astype(int)

# Integer codes for each categorical column.
for cat_col in ('family', 'city', 'state', 'type'):
    df[f'{cat_col}_encoded'] = df[cat_col].astype('category').cat.codes

# Composite key identifying each individual sales series.
df['store_family'] = df['store_nbr'].astype(str) + '_' + df['family']

print(f"  Features created")

# =============================================================================
# CREATE LAG FEATURES PROPERLY
# =============================================================================
print("\n[4/7] Creating lag features (proper method)...")

df = df.sort_values(['store_nbr', 'family', 'date']).reset_index(drop=True)

# One sales series per (store, family); all lag features are computed within
# these groups so different series never leak into each other.
grouped_sales = df.groupby(['store_nbr', 'family'])['sales']

# Plain lags of the target.
for horizon in (7, 14, 28):
    df[f'sales_lag_{horizon}'] = grouped_sales.shift(horizon)

# Trailing means, shifted by one day so today's target never feeds its own
# feature (no leakage).
for horizon in (7, 14, 28):
    df[f'sales_rolling_{horizon}'] = grouped_sales.transform(
        lambda s, w=horizon: s.shift(1).rolling(w, min_periods=1).mean()
    )

print(f"  Lag features created")

# =============================================================================
# PREPARE TRAINING DATA
# =============================================================================
print("\n[5/7] Preparing training data...")

# Features computable for any date without knowing sales (safe for test).
base_features = [
    'store_nbr', 'onpromotion', 'cluster',
    'year', 'month', 'day', 'dayofweek', 'dayofyear', 'weekofyear',
    'is_weekend', 'is_month_start', 'is_month_end', 'is_payday',
    'oil_price', 'is_holiday',
    'family_encoded', 'city_encoded', 'state_encoded', 'type_encoded'
]

# Target-derived features; these are rebuilt recursively at prediction time.
lag_features = ([f'sales_lag_{w}' for w in (7, 14, 28)] +
                [f'sales_rolling_{w}' for w in (7, 14, 28)])

all_features = base_features + lag_features

# Keep only recent history (2016+, past the lag warm-up) and rows where the
# target and every lag feature are available.
train_df = df.loc[(df['is_train'] == 1) & (df['date'] >= '2016-01-01')].copy()
train_df = train_df.dropna(subset=['sales'] + lag_features)

X_train = train_df[all_features]
y_train = train_df['sales']
y_train_log = np.log1p(y_train)  # model the log target (metric is RMSLE)

print(f"  Training samples: {len(X_train)}")
print(f"  Features: {len(all_features)}")

# =============================================================================
# TRAIN MODEL WITH REALISTIC VALIDATION
# =============================================================================
print("\n[6/7] Training LightGBM...")

# Hold out the final 16 days as validation — the same horizon as the test
# period, so the validation score mimics the real forecasting task.
val_cutoff = train_df['date'].max() - pd.Timedelta(days=15)
in_val = (train_df['date'] >= val_cutoff).values
in_tr = ~in_val

X_tr, y_tr = X_train[in_tr], y_train_log[in_tr]
X_val, y_val = X_train[in_val], y_train_log[in_val]

print(f"  Train: {len(X_tr)}, Val: {len(X_val)}")

params = {
    'objective': 'regression',
    'metric': 'rmse',
    'boosting_type': 'gbdt',
    'learning_rate': 0.03,
    'num_leaves': 127,
    'max_depth': 10,
    'min_child_samples': 30,
    'feature_fraction': 0.8,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'lambda_l1': 0.1,
    'lambda_l2': 0.1,
    'verbose': -1,
    'n_jobs': -1,
    'seed': 42,
}

dtrain = lgb.Dataset(X_tr, label=y_tr)
dval = lgb.Dataset(X_val, label=y_val, reference=dtrain)

model = lgb.train(
    params,
    dtrain,
    num_boost_round=2000,
    valid_sets=[dval],
    callbacks=[lgb.early_stopping(100), lgb.log_evaluation(100)],
)

# Score the hold-out back in the original sales space.
val_pred = np.maximum(np.expm1(model.predict(X_val)), 0)
val_actual = np.expm1(y_val)

rmsle = np.sqrt(np.mean((np.log1p(val_pred) - np.log1p(val_actual)) ** 2))
print(f"\n  Validation RMSLE: {rmsle:.5f}")

# =============================================================================
# RECURSIVE PREDICTION FOR TEST
# =============================================================================
print("\n[7/7] Recursive prediction for test set...")

# Test rows, ordered so each date's rows come out in (store, family) order.
test_df = df[df['is_train'] == 0].copy()
test_df = test_df.sort_values(['store_nbr', 'family', 'date']).reset_index(drop=True)

last_train_date = train['date'].max()
print(f"  Last train date: {last_train_date}")

# Seed one running history per (store, family) with the last 28 days of
# actual sales — 28 is the deepest lag, so this is exactly the context the
# first predicted day needs.  Predictions are appended to these histories as
# we go, which is what makes the forecast recursive.
recent_train = df[(df['is_train'] == 1) &
                  (df['date'] > last_train_date - pd.Timedelta(days=28))]
history = {
    key: list(grp.sort_values('date')['sales'])
    for key, grp in recent_train.groupby(['store_nbr', 'family'])
}

test_dates = sorted(test_df['date'].unique())
print(f"  Test dates to predict: {len(test_dates)}")

all_predictions = []

for i, pred_date in enumerate(test_dates):
    day = test_df[test_df['date'] == pred_date].copy()
    keys = list(zip(day['store_nbr'], day['family']))
    hists = [history.get(k, []) for k in keys]

    # Rebuild lag / rolling features from the running histories.  One dict
    # lookup per row replaces the previous full-frame boolean scan per row,
    # which made this loop accidentally quadratic.  Series with too little
    # history keep the features precomputed in step [4/7] (fillna(0) below).
    for w in (7, 14, 28):
        day[f'sales_lag_{w}'] = [
            h[-w] if len(h) >= w else old
            for h, old in zip(hists, day[f'sales_lag_{w}'])
        ]
        day[f'sales_rolling_{w}'] = [
            float(np.mean(h[-w:])) if len(h) >= w else old
            for h, old in zip(hists, day[f'sales_rolling_{w}'])
        ]

    # Predict in log space, invert, and clip negatives to zero.
    X_pred = day[all_features].fillna(0)
    preds = np.maximum(np.expm1(model.predict(X_pred)), 0)

    # Record predictions and feed them back into the histories so the next
    # day's lag features can see them.
    for key, pid, sale in zip(keys, day['id'].values, preds):
        history.setdefault(key, []).append(sale)
        all_predictions.append({'id': int(pid), 'sales': sale})

    if (i + 1) % 4 == 0:
        print(f"  Predicted {i + 1}/{len(test_dates)} days")

print(f"  Total predictions: {len(all_predictions)}")

# Build and write the submission file, ordered by row id.
submission = (pd.DataFrame(all_predictions)
              .sort_values('id')
              .reset_index(drop=True))
submission.to_csv('submission.csv', index=False)

print(f"\n  Submission saved: {len(submission)} rows")
print(f"  Sales range: {submission['sales'].min():.2f} - {submission['sales'].max():.2f}")
print(f"  Sales mean: {submission['sales'].mean():.2f}")

# Report the model's most influential features.
print("\n[Top 10 Features]")
imp = (pd.DataFrame({'feature': all_features,
                     'importance': model.feature_importance()})
       .sort_values('importance', ascending=False))
print(imp.head(10).to_string(index=False))

print("\n" + "=" * 70)
print(f"   🔧 V7 FIXED - Val RMSLE: {rmsle:.5f}")
print("=" * 70)
