#!/usr/bin/env python3
"""
================================================================================
   VESUVIUS WINNER AGENT
   Autonomous Kaggle competition agent for Vesuvius Challenge

   Target: Beat 0.575 (current #1)
   Prize: $100,000
================================================================================
"""

import os
import sys
import json
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Tuple
import argparse

# =============================================================================
# CONFIGURATION
# =============================================================================

# Kaggle competition slug; interpolated into every `kaggle competitions ... -c` command.
COMPETITION = "vesuvius-challenge-surface-detection"
# Leaderboard score the pipeline compares its latest public score against.
TARGET_SCORE = 0.575
# Prize amount; in the visible code it is only printed in the pipeline banner.
PRIZE = 100000

# Paths
# Root of all on-disk state. The directories below (except WORKSPACE itself, which
# is created implicitly as a parent) are created by VesuviusAgent.__init__.
WORKSPACE = Path("/workspace")
DATA_DIR = WORKSPACE / "data"
MODELS_DIR = WORKSPACE / "models"
PREDICTIONS_DIR = WORKSPACE / "predictions"
SUBMISSIONS_DIR = WORKSPACE / "submissions"
LOGS_DIR = WORKSPACE / "logs"

# nnUNet paths
# Exported by VesuviusAgent.__init__ as the nnUNet_raw / nnUNet_preprocessed /
# nnUNet_results environment variables.
NNUNET_RAW = WORKSPACE / "nnUNet_raw"
NNUNET_PREPROCESSED = WORKSPACE / "nnUNet_preprocessed"
NNUNET_RESULTS = WORKSPACE / "nnUNet_results"

# Kaggle limits
# Cap checked by VesuviusAgent.submit against its in-process submission counter.
MAX_SUBMISSIONS_PER_DAY = 5
# NOTE(review): the two limits below are not referenced anywhere in the visible
# code — confirm whether they are enforced elsewhere or are informational only.
MAX_GPU_HOURS_WEEK = 30
MAX_KERNEL_RUNTIME_HOURS = 9


class VesuviusAgent:
    """Autonomous agent for Vesuvius Challenge."""

    def __init__(self, verbose: bool = True):
        """Initialise agent state, workspace directories and nnUNet env vars.

        Args:
            verbose: When true, ``log`` also echoes every line to stdout.
        """
        self.verbose = verbose
        started_at = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.log_file = LOGS_DIR / f"agent_{started_at}.log"
        self.submissions_today = 0
        self.best_score = None

        # Make sure every directory the pipeline touches exists up front.
        workspace_dirs = (
            DATA_DIR, MODELS_DIR, PREDICTIONS_DIR, SUBMISSIONS_DIR, LOGS_DIR,
            NNUNET_RAW, NNUNET_PREPROCESSED, NNUNET_RESULTS,
        )
        for directory in workspace_dirs:
            directory.mkdir(parents=True, exist_ok=True)

        # Exported through os.environ so child processes started later inherit them.
        os.environ.update({
            "nnUNet_raw": str(NNUNET_RAW),
            "nnUNet_preprocessed": str(NNUNET_PREPROCESSED),
            "nnUNet_results": str(NNUNET_RESULTS),
            "nnUNet_compile": "false",
        })

    def log(self, message: str, level: str = "INFO"):
        """Log a message to the console (when verbose) and append it to the log file.

        Args:
            message: Text to record.
            level: Free-form severity tag placed in the line prefix.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_line = f"[{timestamp}] [{level}] {message}"

        if self.verbose:
            print(log_line)

        # Explicit UTF-8: messages can carry arbitrary command output, and the
        # locale default encoding (e.g. ASCII under LANG=C) would raise on it.
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(log_line + "\n")

    def run_command(self, cmd: str, timeout: int = 3600) -> Tuple[bool, str]:
        """Execute ``cmd`` through the shell, logging the command and any failure.

        Args:
            cmd: Shell command line to run.
            timeout: Seconds to wait before giving up.

        Returns:
            ``(ok, output)``: ``ok`` is true when the exit status is 0 and
            ``output`` is stdout followed by stderr. On timeout the output is
            the literal ``"Timeout"``; on any other exception it is the
            exception text.
        """
        self.log(f"Running: {cmd}")
        try:
            completed = subprocess.run(
                cmd,
                shell=True,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
            combined = completed.stdout + completed.stderr
            if completed.returncode != 0:
                # Only the tail is logged to keep the log readable.
                self.log(f"Command failed: {combined[-500:]}", "ERROR")
                return False, combined
            return True, combined
        except subprocess.TimeoutExpired:
            self.log(f"Command timed out after {timeout}s", "ERROR")
            return False, "Timeout"
        except Exception as exc:
            self.log(f"Command error: {exc}", "ERROR")
            return False, str(exc)

    # =========================================================================
    # SETUP
    # =========================================================================

    def setup_kaggle(self) -> bool:
        """Ensure ``~/.kaggle/kaggle.json`` exists for the Kaggle CLI.

        An existing file is left untouched. Otherwise the file is created from
        the ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` environment variables.
        Credentials are never taken from source code: secrets committed to a
        repository must be treated as leaked.

        Returns:
            True when a credentials file is present afterwards; False when it
            is missing and the environment variables are not both set.
        """
        self.log("Setting up Kaggle credentials...")

        kaggle_dir = Path.home() / ".kaggle"
        kaggle_dir.mkdir(parents=True, exist_ok=True)

        kaggle_json = kaggle_dir / "kaggle.json"
        if not kaggle_json.exists():
            username = os.environ.get("KAGGLE_USERNAME")
            key = os.environ.get("KAGGLE_KEY")
            if not username or not key:
                self.log(
                    f"Kaggle credentials missing: create {kaggle_json} or set "
                    "KAGGLE_USERNAME and KAGGLE_KEY",
                    "ERROR",
                )
                return False

            # Create the file with 0600 from the start so the key is never
            # readable by other users, not even briefly before a chmod.
            fd = os.open(kaggle_json, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w") as f:
                json.dump({"username": username, "key": key}, f)

            kaggle_json.chmod(0o600)

        self.log("Kaggle credentials configured")
        return True

    def setup_environment(self) -> bool:
        """Check for a GPU, configure Kaggle credentials and install dependencies.

        Returns:
            False when ``nvidia-smi`` fails or Kaggle credentials cannot be
            configured; True otherwise. Dependency installation is
            best-effort: failed installs are logged as warnings and do not
            change the result.
        """
        self.log("=" * 60)
        self.log("VESUVIUS WINNER AGENT - SETUP")
        self.log("=" * 60)

        # Check GPU
        success, _ = self.run_command("nvidia-smi")
        if not success:
            self.log("No GPU detected!", "ERROR")
            return False
        self.log("GPU detected")

        # Setup Kaggle; without credentials every later kaggle call would fail.
        if not self.setup_kaggle():
            self.log("Kaggle credential setup failed", "ERROR")
            return False

        # Install dependencies
        self.log("Installing dependencies...")
        deps = [
            "pip install -q kaggle",
            "pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121",
            # Quoted so the shell cannot expand "[all]" as a glob pattern.
            "pip install -q 'monai[all]' nnunetv2 nibabel tifffile tqdm scipy",
        ]
        for dep in deps:
            installed, _ = self.run_command(dep, timeout=600)
            if not installed:
                self.log(f"Dependency install failed: {dep}", "WARNING")

        self.log("Environment setup complete")
        return True

    def download_data(self) -> bool:
        """Download the competition archive(s) and extract them into DATA_DIR.

        The download is skipped when ``DATA_DIR / "train_images"`` already
        exists.

        Returns:
            True when the data is already present or was downloaded and every
            archive extracted; False when the download fails, no archive is
            found afterwards, or an archive cannot be extracted.
        """
        if (DATA_DIR / "train_images").exists():
            self.log("Data already downloaded")
            return True

        self.log("Downloading competition data...")
        success, _ = self.run_command(
            f"kaggle competitions download -c {COMPETITION} -p {DATA_DIR}",
            timeout=1800
        )
        if not success:
            return False

        # Extract in-process, one archive at a time. `unzip a.zip b.zip` would
        # treat the second name as a member to extract, and its exit status
        # was previously ignored.
        import zipfile

        archives = sorted(DATA_DIR.glob("*.zip"))
        if not archives:
            self.log(f"No .zip archive found in {DATA_DIR} after download", "ERROR")
            return False

        for archive in archives:
            try:
                with zipfile.ZipFile(archive) as zf:
                    zf.extractall(DATA_DIR)
            except (zipfile.BadZipFile, OSError) as e:
                self.log(f"Failed to extract {archive}: {e}", "ERROR")
                return False

        self.log("Data downloaded and extracted")
        return True

    # =========================================================================
    # TRAINING
    # =========================================================================

    def prepare_nnunet_data(self, dataset_id: int = 100) -> bool:
        """Build the nnUNet raw dataset folder from the downloaded training data.

        Each ``train_images/<case>.tif`` with a same-named file in
        ``train_labels`` is symlinked as ``imagesTr/<case>_0000.tif`` and
        ``labelsTr/<case>.tif``, then ``dataset.json`` is written.

        Args:
            dataset_id: Numeric id, zero-padded to three digits in the
                dataset folder name.

        Returns:
            True when at least one training case is available; False when the
            training images directory is missing or no case could be linked.
        """
        self.log("Preparing nnUNet dataset...")

        dataset_name = f"Dataset{dataset_id:03d}_VesuviusSurface"
        dataset_dir = NNUNET_RAW / dataset_name
        images_dir = dataset_dir / "imagesTr"
        labels_dir = dataset_dir / "labelsTr"

        images_dir.mkdir(parents=True, exist_ok=True)
        labels_dir.mkdir(parents=True, exist_ok=True)

        # Link training data
        train_images = DATA_DIR / "train_images"
        train_labels = DATA_DIR / "train_labels"

        if not train_images.exists():
            self.log(f"Training images not found: {train_images}", "ERROR")
            return False

        image_files = sorted(train_images.glob("*.tif"))
        self.log(f"Found {len(image_files)} training cases")

        skipped = 0
        for img_path in image_files:
            label_path = train_labels / img_path.name
            if not label_path.exists():
                skipped += 1
                continue

            case_id = img_path.stem
            links = (
                (images_dir / f"{case_id}_0000.tif", img_path),
                (labels_dir / f"{case_id}.tif", label_path),
            )
            for dst, src in links:
                # exists() is False for a dangling symlink, so drop it first;
                # otherwise symlink_to would raise FileExistsError.
                if dst.is_symlink() and not dst.exists():
                    dst.unlink()
                if not dst.exists():
                    dst.symlink_to(src.resolve())

        if skipped:
            self.log(f"Skipped {skipped} images without a matching label", "WARNING")

        num_training = sum(1 for _ in images_dir.glob("*.tif"))
        if num_training == 0:
            self.log("No labelled training cases available", "ERROR")
            return False

        # Create dataset.json
        dataset_json = {
            "channel_names": {"0": "CT"},
            # NOTE(review): "ignore" is presumably nnUNet's special ignore
            # label — confirm against the nnUNet docs.
            "labels": {"background": 0, "surface": 1, "ignore": 2},
            "numTraining": num_training,
            "file_ending": ".tif",
            "overwrite_image_reader_writer": "SimpleTiffIO"
        }

        with open(dataset_dir / "dataset.json", "w", encoding="utf-8") as f:
            json.dump(dataset_json, f, indent=4)

        self.log(f"nnUNet dataset prepared: {dataset_json['numTraining']} cases")
        return True

    def train_nnunet(self, dataset_id: int = 100, epochs: int = 250,
                     configuration: str = "3d_fullres",
                     planner: str = "nnUNetPlannerResEncM") -> bool:
        """Run nnUNet planning/preprocessing and then training on all data.

        Args:
            dataset_id: nnUNet dataset id (zero-padded to three digits).
            epochs: Selects the trainer class ``nnUNetTrainer_<epochs>epochs``;
                1000 selects the plain ``nnUNetTrainer``.
                NOTE(review): nnUNet presumably only ships trainer variants
                for specific epoch counts — confirm before using other values.
            configuration: nnUNet configuration name.
            planner: Experiment planner class passed to ``-pl``.

        Returns:
            True when both preprocessing and training exit successfully.
        """
        self.log("=" * 60)
        self.log(f"TRAINING nnUNet - {epochs} epochs")
        self.log("=" * 60)

        # Derive the plans identifier that the chosen planner writes. The
        # ResEnc presets each produce their own plans file, so the L/XL
        # planners must not be paired with the M plans name.
        resenc_prefix = "nnUNetPlannerResEnc"
        preset = planner[len(resenc_prefix):] if planner.startswith(resenc_prefix) else ""
        if preset in ("M", "L", "XL"):
            plans_name = f"nnUNetResEncUNet{preset}Plans"
        elif "ResEnc" in planner:
            # Unrecognised ResEnc planner: keep the previous default mapping.
            plans_name = "nnUNetResEncUNetMPlans"
        else:
            plans_name = "nnUNetPlans"

        # Preprocessing
        self.log("Running preprocessing...")
        preprocess_cmd = f"nnUNetv2_plan_and_preprocess -d {dataset_id:03d} -pl {planner} -c {configuration} -np 4"
        success, _ = self.run_command(preprocess_cmd, timeout=7200)

        if not success:
            self.log("Preprocessing failed", "ERROR")
            return False

        # Training
        self.log("Starting training...")
        trainer = f"nnUNetTrainer_{epochs}epochs" if epochs != 1000 else "nnUNetTrainer"
        train_cmd = f"nnUNetv2_train {dataset_id:03d} {configuration} all -p {plans_name} -tr {trainer}"

        success, _ = self.run_command(train_cmd, timeout=28800)  # 8 hours

        if success:
            self.log("Training complete!")
        else:
            self.log("Training failed", "ERROR")

        return success

    # =========================================================================
    # INFERENCE
    # =========================================================================

    def run_inference(self, dataset_id: int = 100,
                      configuration: str = "3d_fullres",
                      epochs: int = 250,
                      plans_name: str = "nnUNetResEncUNetMPlans") -> bool:
        """Run nnUNet prediction on the test images.

        Test images are symlinked into ``WORKSPACE / "test_input"`` with the
        ``_0000`` channel suffix, then ``nnUNetv2_predict`` writes
        predictions (with probabilities) into PREDICTIONS_DIR.

        Args:
            dataset_id: nnUNet dataset id (zero-padded to three digits).
            configuration: nnUNet configuration name.
            epochs: Must match the value used for training, because it
                selects the trainer class name.
            plans_name: Plans identifier the model was trained with.

        Returns:
            True when prediction exits successfully; False when there is
            nothing to predict on or the command fails.
        """
        self.log("Running inference...")

        # Prepare test input
        test_input = WORKSPACE / "test_input"
        test_input.mkdir(exist_ok=True)

        test_images = DATA_DIR / "test_images"
        if test_images.exists():
            for img in test_images.glob("*.tif"):
                dst = test_input / f"{img.stem}_0000.tif"
                # exists() is False for a dangling symlink; remove it so that
                # symlink_to does not raise FileExistsError.
                if dst.is_symlink() and not dst.exists():
                    dst.unlink()
                if not dst.exists():
                    dst.symlink_to(img.resolve())

        # Fail early instead of launching a long-timeout command on nothing.
        if not any(test_input.glob("*_0000.tif")):
            self.log(f"No test images available in {test_input}", "ERROR")
            return False

        # Run nnUNet inference
        trainer = f"nnUNetTrainer_{epochs}epochs" if epochs != 1000 else "nnUNetTrainer"

        pred_cmd = f"nnUNetv2_predict -d {dataset_id:03d} -c {configuration} -f all "
        pred_cmd += f"-i {test_input} -o {PREDICTIONS_DIR} -p {plans_name} -tr {trainer} "
        pred_cmd += "--save_probabilities"

        success, _ = self.run_command(pred_cmd, timeout=14400)  # 4 hours

        if success:
            self.log("Inference complete!")
        else:
            self.log("Inference failed", "ERROR")

        return success

    def postprocess_predictions(self) -> bool:
        """Threshold saved probabilities and clean the resulting masks.

        For every ``*.npz`` in PREDICTIONS_DIR the ``probabilities`` array is
        thresholded at 0.5 on index 1, passed through
        ``postprocess_vesuvius`` and written as ``postprocessed/<case>.tif``.

        Returns:
            True when every probability file was processed; False when none
            were found or any step raised (the error is logged, not raised).
        """
        self.log("Post-processing predictions...")

        try:
            import numpy as np
            import tifffile
            from models.postprocessing import postprocess_vesuvius

            pred_files = sorted(PREDICTIONS_DIR.glob("*.npz"))
            if not pred_files:
                # Do not create an empty output directory: its mere presence
                # would make later steps look there for predictions.
                self.log(f"No probability files (*.npz) found in {PREDICTIONS_DIR}", "WARNING")
                return False

            output_dir = PREDICTIONS_DIR / "postprocessed"
            output_dir.mkdir(exist_ok=True)

            for pred_file in pred_files:
                case_id = pred_file.stem
                # Context manager closes the underlying zip file handle.
                with np.load(pred_file) as data:
                    probs = data['probabilities']

                # Get surface class (class 1)
                # assumes probabilities are class-first — TODO confirm
                pred = (probs[1] > 0.5).astype(np.uint8)

                # Post-process
                pred = postprocess_vesuvius(
                    pred,
                    min_component_size=100,
                    max_hole_size=500,
                    closing_iterations=2,
                    opening_iterations=1
                )

                # Save
                tifffile.imwrite(output_dir / f"{case_id}.tif", pred)

            self.log("Post-processing complete!")
            return True

        except Exception as e:
            # Deliberate catch-all: post-processing is optional and callers
            # rely on a boolean result rather than an exception.
            self.log(f"Post-processing error: {e}", "ERROR")
            return False

    # =========================================================================
    # SUBMISSION
    # =========================================================================

    def create_submission(self) -> Optional[Path]:
        """Zip the prediction TIFFs into a timestamped submission archive.

        Post-processed predictions are preferred. When that directory is
        missing or contains no ``.tif`` files, the raw predictions in
        PREDICTIONS_DIR are used instead.

        Returns:
            Path of the created ZIP, or None when no predictions exist.
        """
        self.log("Creating submission...")

        import zipfile

        # Find predictions: first candidate directory that actually has TIFFs.
        tiff_files = []
        for pred_dir in (PREDICTIONS_DIR / "postprocessed", PREDICTIONS_DIR):
            if pred_dir.exists():
                # Sorted so the archive contents are in a deterministic order.
                tiff_files = sorted(pred_dir.glob("*.tif"))
            if tiff_files:
                self.log(f"Using {len(tiff_files)} predictions from {pred_dir}")
                break

        if not tiff_files:
            self.log("No predictions found!", "ERROR")
            return None

        # Create ZIP
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_path = SUBMISSIONS_DIR / f"submission_{timestamp}.zip"

        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            for tif in tiff_files:
                # Stored flat, by file name only.
                zf.write(tif, tif.name)

        self.log(f"Submission created: {zip_path}")
        self.log(f"Size: {zip_path.stat().st_size / (1024*1024):.1f} MB")

        return zip_path

    def submit(self, submission_path: Path, message: str = "Auto submission") -> bool:
        """Submit a file to the competition through the Kaggle CLI.

        Args:
            submission_path: File to upload.
            message: Submission description.

        Returns:
            True when the CLI exits successfully; False when the in-process
            submission counter has reached MAX_SUBMISSIONS_PER_DAY, the file
            does not exist, or the CLI fails.

        Note:
            The counter lives on this instance only; it is not persisted and
            does not reflect submissions made by other runs.
        """
        if self.submissions_today >= MAX_SUBMISSIONS_PER_DAY:
            self.log(f"Daily submission limit reached ({MAX_SUBMISSIONS_PER_DAY})", "WARNING")
            return False

        if not Path(submission_path).is_file():
            self.log(f"Submission file not found: {submission_path}", "ERROR")
            return False

        # Quote the interpolated values: the command runs through a shell, so
        # spaces or quotes in the path/message would otherwise break it.
        import shlex

        self.log(f"Submitting: {submission_path}")
        cmd = (
            f"kaggle competitions submit -c {COMPETITION} "
            f"-f {shlex.quote(str(submission_path))} -m {shlex.quote(message)}"
        )
        success, _ = self.run_command(cmd)

        if success:
            self.submissions_today += 1
            self.log(f"Submitted! ({self.submissions_today}/{MAX_SUBMISSIONS_PER_DAY} today)")
        else:
            self.log("Submission failed", "ERROR")

        return success

    def check_leaderboard(self) -> Optional[float]:
        """Return the public score of the most recent submission, if available.

        Returns:
            The ``publicScore`` of the first data row in the CLI's CSV
            output, or None when the command fails, the CSV has no such
            column or no rows, or the latest submission has no numeric score
            yet (for example while it is still being evaluated).
        """
        import csv

        cmd = f"kaggle competitions submissions -c {COMPETITION} --csv"
        success, output = self.run_command(cmd)
        if not success:
            return None

        # run_command returns stdout and stderr combined, so the CSV header
        # is not guaranteed to be the very first line.
        lines = output.strip().splitlines()
        header_idx = next(
            (i for i, line in enumerate(lines) if "publicScore" in line), None
        )
        if header_idx is None:
            return None

        # A real CSV parser is needed: fields such as the description may be
        # quoted and contain commas.
        rows = [row for row in csv.reader(lines[header_idx:]) if row]
        header = [column.strip() for column in rows[0]]
        if "publicScore" not in header or len(rows) < 2:
            return None

        score_idx = header.index("publicScore")
        latest = rows[1]
        try:
            score = float(latest[score_idx])
        except (IndexError, ValueError):
            self.log("Latest submission has no public score yet", "WARNING")
            return None

        self.log(f"Latest LB score: {score}")
        return score

    # =========================================================================
    # MAIN PIPELINE
    # =========================================================================

    def run_full_pipeline(self, skip_training: bool = False, epochs: int = 250) -> bool:
        """Run setup, download, data preparation, training, inference and submission.

        Args:
            skip_training: When true, go straight from data preparation to
                inference (a trained model must already exist).
            epochs: Epoch setting forwarded to both training and inference so
                they refer to the same trainer.

        Returns:
            True when every mandatory step succeeded. Post-processing is
            optional: if it fails, the submission is built from whatever
            predictions ``create_submission`` can find.
        """
        self.log("=" * 60)
        self.log("VESUVIUS WINNER AGENT - FULL PIPELINE")
        self.log(f"Target: {TARGET_SCORE} | Prize: ${PRIZE:,}")
        self.log("=" * 60)

        # 1. Setup
        if not self.setup_environment():
            return False

        # 2. Download data
        if not self.download_data():
            return False

        # 3. Prepare nnUNet data
        if not self.prepare_nnunet_data():
            return False

        # 4. Train (optional)
        if not skip_training:
            if not self.train_nnunet(epochs=epochs):
                return False

        # 5. Inference
        if not self.run_inference(epochs=epochs):
            return False

        # 6. Post-process (best effort)
        if not self.postprocess_predictions():
            self.log("Post-processing did not complete; continuing without it", "WARNING")

        # 7. Create submission
        submission = self.create_submission()
        if not submission:
            return False

        # 8. Submit
        if not self.submit(submission, "nnUNet ResEnc + postprocessing"):
            # Without a new submission the leaderboard would only show an
            # older score, so stop here rather than report it as the result.
            return False

        # 9. Check score
        score = self.check_leaderboard()
        # `is not None` so that a legitimate score of 0.0 is still reported.
        if score is not None:
            # The goal is to beat TARGET_SCORE, i.e. exceed it.
            # NOTE(review): assumes the metric is higher-is-better — confirm.
            if score > TARGET_SCORE:
                self.log(f"NEW BEST! {score} > {TARGET_SCORE}", "SUCCESS")
            else:
                gap = TARGET_SCORE - score
                self.log(f"Gap to target: {gap:.4f}")

        self.log("=" * 60)
        self.log("PIPELINE COMPLETE")
        self.log("=" * 60)

        return True

    def status(self):
        """Log a summary of GPU, disk, data, model, submission and leaderboard state."""
        self.log("=" * 60)
        self.log("VESUVIUS AGENT STATUS")
        self.log("=" * 60)

        # GPU — run_command only logs output on failure, so log it here.
        ok, output = self.run_command(
            "nvidia-smi --query-gpu=name,memory.used,memory.total --format=csv"
        )
        if ok:
            self.log(f"GPU:\n{output.strip()}")

        # Disk
        ok, output = self.run_command(f"df -h {WORKSPACE}")
        if ok:
            self.log(f"Disk:\n{output.strip()}")

        # Data — DATA_DIR itself is created in __init__, so test for the
        # same marker directory that download_data uses.
        if (DATA_DIR / "train_images").exists():
            self.log(f"Data: {DATA_DIR} exists")
        else:
            self.log("Data: NOT DOWNLOADED")

        # Models
        model_dirs = list(NNUNET_RESULTS.glob("Dataset*"))
        if model_dirs:
            self.log(f"Models: {len(model_dirs)} found")
        else:
            self.log("Models: NONE")

        # Submissions
        if SUBMISSIONS_DIR.exists():
            subs = list(SUBMISSIONS_DIR.glob("*.zip"))
            self.log(f"Submissions: {len(subs)} created")

        # Leaderboard (check_leaderboard logs the score itself)
        self.check_leaderboard()


def main() -> int:
    """Parse the command line, run the requested agent command and report the result.

    Returns:
        Process exit status: 0 when the command succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Vesuvius Winner Agent")
    parser.add_argument("command", choices=["run", "setup", "train", "infer", "submit", "status"],
                        help="Command to execute")
    parser.add_argument("--skip-training", action="store_true", help="Skip training phase")
    parser.add_argument("--epochs", type=int, default=250, help="Training epochs")

    args = parser.parse_args()

    agent = VesuviusAgent(verbose=True)

    if args.command == "run":
        ok = agent.run_full_pipeline(skip_training=args.skip_training)
    elif args.command == "setup":
        # Both steps are attempted, so data can still be fetched when the
        # environment check fails.
        env_ok = agent.setup_environment()
        data_ok = agent.download_data()
        ok = env_ok and data_ok
    elif args.command == "train":
        # Training on an unprepared dataset cannot succeed; stop early.
        ok = agent.prepare_nnunet_data() and agent.train_nnunet(epochs=args.epochs)
    elif args.command == "infer":
        # The epoch count selects the trainer name, so it must match training.
        ok = agent.run_inference(epochs=args.epochs) and agent.postprocess_predictions()
    elif args.command == "submit":
        sub = agent.create_submission()
        ok = sub is not None and agent.submit(sub)
    else:  # "status"
        agent.status()
        ok = True

    return 0 if ok else 1


if __name__ == "__main__":
    sys.exit(main())
