#!/usr/bin/env python3
"""
RunPod Manager - Gestione Pod GPU
Uso: python manager.py [comando] [opzioni]
"""

import os
import sys
import json
import time
import argparse
from pathlib import Path
from datetime import datetime

# Load the .env file one directory above this script so that
# RUNPOD_API_KEY is available through os.getenv below.
from dotenv import load_dotenv
load_dotenv(Path(__file__).parent.parent / '.env')

import runpod

# Configure the RunPod SDK with the API key from the environment.
# May be None at this point; main() validates it before any API call.
runpod.api_key = os.getenv('RUNPOD_API_KEY')

# Known GPU types: VRAM in GB, approximate spot price (USD/hour), and the
# exact RunPod display-name identifier used when creating pods.
# NOTE(review): prices are hard-coded snapshots — confirm against current
# RunPod pricing before relying on them.
GPU_CATALOG = {
    "RTX_3090": {"vram": 24, "spot_price": 0.22, "id": "NVIDIA GeForce RTX 3090"},
    "RTX_4090": {"vram": 24, "spot_price": 0.34, "id": "NVIDIA GeForce RTX 4090"},
    "A100_40GB": {"vram": 40, "spot_price": 0.99, "id": "NVIDIA A100 40GB"},
    "A100_80GB": {"vram": 80, "spot_price": 1.24, "id": "NVIDIA A100 80GB"},
    "H100": {"vram": 80, "spot_price": 2.49, "id": "NVIDIA H100 80GB HBM3"},
}

# Recommended Docker images, keyed by short template alias.
DOCKER_TEMPLATES = {
    "pytorch": "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
    "pytorch-light": "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-runtime",
    "transformers": "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
    "stable-diffusion": "runpod/stable-diffusion:web-ui-10.2.1",
}


def list_gpus():
    """Print the GPUs available on RunPod with VRAM, spot price and availability.

    Falls back to the static GPU_CATALOG table if the API call fails.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("GPU DISPONIBILI SU RUNPOD")
    print(banner)

    try:
        gpus = runpod.get_gpus()

        print(f"\n{'GPU':<25} {'VRAM':<8} {'Spot $/h':<12} {'Disponibile':<12}")
        print("-" * 60)

        for gpu in gpus:
            display_name = gpu.get('displayName', 'N/A')
            memory = gpu.get('memoryInGb', 'N/A')
            availability = "Si" if gpu.get('available', False) else "No"

            # Look up an approximate spot price in the static catalog.
            spot = "N/A"
            for catalog_key, info in GPU_CATALOG.items():
                matched = (catalog_key.replace("_", " ") in display_name.upper()
                           or info["id"] in display_name)
                if matched:
                    spot = f"${info['spot_price']:.2f}"
                    break

            print(f"{display_name:<25} {memory}GB{'':<4} {spot:<12} {availability:<12}")

    except Exception as e:
        # API unavailable: show the hard-coded catalog instead.
        print(f"Errore: {e}")
        print("\nUsando catalogo statico:")
        print(f"\n{'GPU':<20} {'VRAM':<8} {'Spot $/h':<12}")
        print("-" * 45)
        for catalog_key, info in GPU_CATALOG.items():
            print(f"{catalog_key:<20} {info['vram']}GB{'':<4} ${info['spot_price']:.2f}")


def list_pods():
    """Print a table of the currently active pods (id, name, GPU, status)."""
    banner = "=" * 60
    print("\n" + banner)
    print("POD ATTIVI")
    print(banner)

    try:
        pods = runpod.get_pods()

        if not pods:
            print("\nNessun pod attivo.")
            return

        print(f"\n{'ID':<15} {'Nome':<20} {'GPU':<15} {'Status':<12}")
        print("-" * 65)

        for pod in pods:
            # Name and GPU type are truncated so columns stay aligned.
            row = (
                pod.get('id', 'N/A'),
                pod.get('name', 'N/A')[:20],
                pod.get('gpuType', 'N/A')[:15],
                pod.get('desiredStatus', 'N/A'),
            )
            print("{:<15} {:<20} {:<15} {:<12}".format(*row))

    except Exception as e:
        print(f"Errore nel recupero pod: {e}")


def create_pod(name, gpu_type="RTX_4090", template="pytorch", disk_size=50, spot=True):
    """Create a new pod and return the API response dict, or None on failure.

    Args:
        name: Pod name shown in the RunPod console.
        gpu_type: A GPU_CATALOG key, or a raw RunPod GPU identifier.
        template: A DOCKER_TEMPLATES key, or a raw Docker image name.
        disk_size: Persistent volume size in GB (mounted at /workspace).
        spot: If True, the pod is placed on the COMMUNITY cloud;
            otherwise on the SECURE cloud.
    """
    # Unknown keys fall through unchanged, so raw identifiers also work.
    gpu_id = GPU_CATALOG.get(gpu_type, {}).get("id", gpu_type)
    docker_image = DOCKER_TEMPLATES.get(template, template)

    print(f"\nCreazione pod...")
    for label, value in (
        ("Nome", name),
        ("GPU", gpu_id),
        ("Immagine", docker_image),
        ("Disco", f"{disk_size}GB"),
        ("Spot", "Si" if spot else "No"),
    ):
        print(f"  {label}: {value}")

    try:
        pod = runpod.create_pod(
            name=name,
            image_name=docker_image,
            gpu_type_id=gpu_id,
            # NOTE(review): spot pods are mapped to the COMMUNITY cloud and
            # on-demand to SECURE — confirm this pairing is intended.
            cloud_type="COMMUNITY" if spot else "SECURE",
            volume_in_gb=disk_size,
            container_disk_in_gb=20,
            ports="8888/http,22/tcp",  # Jupyter + SSH
            volume_mount_path="/workspace",
        )

        print(f"\n✓ Pod creato con successo!")
        print(f"  ID: {pod.get('id')}")
        print(f"  Status: {pod.get('desiredStatus')}")

        return pod

    except Exception as e:
        print(f"\n✗ Errore nella creazione: {e}")
        return None


def stop_pod(pod_id):
    """Stop a pod without deleting it."""
    try:
        runpod.stop_pod(pod_id)
    except Exception as e:
        print(f"✗ Errore: {e}")
    else:
        print(f"✓ Pod {pod_id} fermato")


def terminate_pod(pod_id):
    """Terminate a pod, removing it entirely."""
    try:
        runpod.terminate_pod(pod_id)
    except Exception as e:
        print(f"✗ Errore: {e}")
    else:
        print(f"✓ Pod {pod_id} terminato")


def get_pod_ssh(pod_id):
    """Print the SSH command for connecting to the given pod, if available."""
    try:
        for pod in runpod.get_pods():
            if pod.get('id') != pod_id:
                continue
            # Found the pod: look for the public mapping of port 22.
            runtime = pod.get('runtime', {})
            if runtime:
                for port in runtime.get('ports', []):
                    if port.get('privatePort') == 22:
                        host = port.get('ip')
                        mapped_port = port.get('publicPort')
                        print(f"\nComando SSH:")
                        print(f"  ssh root@{host} -p {mapped_port} -i ~/.ssh/id_ed25519")
                        return
            print("Pod non ancora pronto o SSH non disponibile")
            return
        print(f"Pod {pod_id} non trovato")
    except Exception as e:
        print(f"Errore: {e}")


def check_balance():
    """Print the account's current spend rate in USD per hour.

    NOTE(review): despite the name, this reads ``currentSpendPerHr`` from
    the user record, not the remaining credit balance — verify intent.
    """
    try:
        user = runpod.get_user()
        spend_rate = user.get('currentSpendPerHr', 0)
        print(f"\nSpesa corrente: ${spend_rate:.4f}/ora")
    except Exception as e:
        print(f"Errore nel recupero balance: {e}")


def main():
    """Parse the command line and dispatch to the requested action."""
    parser = argparse.ArgumentParser(description='RunPod Manager')
    subparsers = parser.add_subparsers(dest='command', help='Comandi disponibili')

    # Commands without extra arguments.
    subparsers.add_parser('list-gpus', help='Mostra GPU disponibili')
    subparsers.add_parser('list-pods', help='Mostra pod attivi')
    subparsers.add_parser('balance', help='Mostra credito')

    # create: pod name plus optional hardware/image settings.
    create_parser = subparsers.add_parser('create', help='Crea nuovo pod')
    create_parser.add_argument('name', help='Nome del pod')
    create_parser.add_argument('--gpu', default='RTX_4090',
                              choices=list(GPU_CATALOG.keys()),
                              help='Tipo di GPU')
    create_parser.add_argument('--template', default='pytorch',
                              choices=list(DOCKER_TEMPLATES.keys()),
                              help='Template Docker')
    create_parser.add_argument('--disk', type=int, default=50,
                              help='Dimensione disco GB')
    create_parser.add_argument('--on-demand', action='store_true',
                              help='Usa on-demand invece di spot')

    # Commands that operate on an existing pod id.
    stop_parser = subparsers.add_parser('stop', help='Ferma pod')
    stop_parser.add_argument('pod_id', help='ID del pod')

    term_parser = subparsers.add_parser('terminate', help='Termina pod')
    term_parser.add_argument('pod_id', help='ID del pod')

    ssh_parser = subparsers.add_parser('ssh', help='Mostra comando SSH')
    ssh_parser.add_argument('pod_id', help='ID del pod')

    args = parser.parse_args()

    # Fail fast if the API key was never configured.
    if not runpod.api_key:
        print("ERRORE: RUNPOD_API_KEY non configurata!")
        print("Copia .env.example in .env e inserisci la tua API key")
        sys.exit(1)

    # Dispatch table replaces an if/elif ladder; unknown/missing command
    # falls through to the help text.
    actions = {
        'list-gpus': lambda: list_gpus(),
        'list-pods': lambda: list_pods(),
        'create': lambda: create_pod(args.name, args.gpu, args.template,
                                     args.disk, not args.on_demand),
        'stop': lambda: stop_pod(args.pod_id),
        'terminate': lambda: terminate_pod(args.pod_id),
        'ssh': lambda: get_pod_ssh(args.pod_id),
        'balance': lambda: check_balance(),
    }
    handler = actions.get(args.command)
    if handler is None:
        parser.print_help()
    else:
        handler()


# Script entry point.
if __name__ == '__main__':
    main()
