#!/usr/bin/env python3
"""
Convert PyTorch regime models to orchestrator JSON format and load into database.
"""

import json
import os
from pathlib import Path

import numpy as np
import psycopg2
import torch

# Database config.
# Connection settings are read from the standard libpq environment variables
# (PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD) so credentials never live in
# source control. Non-secret settings keep their previous defaults.
# NOTE(review): a password used to be hard-coded here and was committed to the
# repository -- treat it as compromised and rotate it.
# When PGPASSWORD is unset an empty password is passed; libpq normally falls
# back to ~/.pgpass in that case -- verify on the target host.
DB_CONFIG = {
    'host': os.environ.get('PGHOST', 'localhost'),
    'port': int(os.environ.get('PGPORT', '5432')),
    'dbname': os.environ.get('PGDATABASE', 'bestrading'),
    'user': os.environ.get('PGUSER', 'bestrading'),
    'password': os.environ.get('PGPASSWORD', ''),
}

# Model paths
MODEL_DIR = Path('/var/www/html/bestrading.cuttalo.com/models/btc_v6')
PAIR = 'BTC/EUR'

def load_pytorch_model(path: str) -> dict:
    """Read a checkpoint file onto the CPU and return its weights plus metadata.

    The checkpoint must contain a ``'model_state'`` entry holding the state
    dict. ``'config'``, ``'input_dim'`` and ``'regime'`` are optional and
    default to ``{}``, ``9`` and ``'unknown'`` respectively.

    Args:
        path: Filesystem path of the ``.pt`` checkpoint.

    Returns:
        A dict with keys ``state_dict``, ``config``, ``input_dim``, ``regime``.

    Raises:
        KeyError: If the checkpoint has no ``'model_state'`` entry.
    """
    # NOTE(review): torch.load is pickle-based; only point it at trusted files.
    ckpt = torch.load(path, map_location='cpu')

    result = {'state_dict': ckpt['model_state']}
    result['config'] = ckpt.get('config', {})
    result['input_dim'] = ckpt.get('input_dim', 9)
    result['regime'] = ckpt.get('regime', 'unknown')
    return result

def _to_numpy(tensor) -> np.ndarray:
    """Return *tensor* as a CPU NumPy array.

    bfloat16 has no NumPy equivalent (``Tensor.numpy()`` raises on it), so it
    is upcast to float32 first. Every other dtype is passed through unchanged.
    """
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.numpy()


def convert_to_orchestrator_format(state_dict: dict, input_dim: int,
                                   *, num_transformer_layers: int = 2) -> dict:
    """
    Convert PyTorch state dict to orchestrator neural network format.

    Orchestrator expects:
    {
        "type": "PureNN",
        "network": {
            "layers": [
                {"weights": [[...]], "biases": [...], "activation": "leaky_relu"},
                ...
            ]
        }
    }

    This is a lossy, simplified approximation of the transformer, not an exact
    conversion. Layers are emitted in this order, each only if its weights are
    present in *state_dict*:

    1. ``embed`` -> one leaky_relu layer.
    2. For each of the first *num_transformer_layers* encoder layers, the
       first ``d_model`` rows of the feed-forward ``linear1`` projection ->
       one leaky_relu layer of width ``d_model``. Attention, ``linear2``,
       residual connections and layer norms are dropped.
    3. ``actor_mean`` expanded into a 3-logit linear layer (FLAT, LONG, SHORT).

    Each emitted layer's input width matches the previous layer's output width.

    Args:
        state_dict: Mapping of parameter name -> tensor.
        input_dim: Currently unused; kept for caller compatibility.
        num_transformer_layers: How many ``transformer.layers.<i>`` entries to
            look for (default 2, matching the original model).

    Returns:
        ``{"type": "PureNN", "network": {"layers": [...]}}``
    """
    layers = []

    # embed.weight: (hidden_dim, input_dim+1), embed.bias: (hidden_dim,)
    if 'embed.weight' in state_dict:
        layers.append({
            'weights': _to_numpy(state_dict['embed.weight']).tolist(),
            'biases': _to_numpy(state_dict['embed.bias']).tolist(),
            'activation': 'leaky_relu',
        })

    for i in range(num_transformer_layers):
        prefix = f'transformer.layers.{i}.linear1'
        if f'{prefix}.weight' not in state_dict:
            continue
        w1 = _to_numpy(state_dict[f'{prefix}.weight'])  # (dim_feedforward, d_model)
        b1 = _to_numpy(state_dict[f'{prefix}.bias'])

        # Emit exactly d_model outputs. The next layer is another linear1 or
        # actor_mean, and both consume d_model inputs. A fixed row cap (e.g. 64)
        # only lines up when d_model equals that cap. If dim_feedforward <
        # d_model, the missing rows are zero-padded (they output 0).
        d_model = w1.shape[1]
        kept = min(d_model, w1.shape[0])
        weights = np.zeros((d_model, d_model), dtype=w1.dtype)
        biases = np.zeros(d_model, dtype=b1.dtype)
        weights[:kept] = w1[:kept]
        biases[:kept] = b1[:kept]

        layers.append({
            'weights': weights.tolist(),
            'biases': biases.tolist(),
            'activation': 'leaky_relu',
        })

    # Actor output layer: expand the single actor output into 3 class logits.
    if 'actor_mean.weight' in state_dict:
        # presumably (1, d_model); only row 0 / bias 0 are used.
        actor_w = _to_numpy(state_dict['actor_mean.weight'])
        actor_b = _to_numpy(state_dict['actor_mean.bias'])

        output_w = np.zeros((3, actor_w.shape[1]))
        output_b = np.zeros(3)

        # FLAT: -0.5 * (w.x), no bias.
        # NOTE(review): LONG and SHORT below are exact negations, so FLAT can
        # only win the argmax when the actor bias b > 0 and w.x lies in
        # (-2b, -2b/3). Confirm this is the intended FLAT behaviour.
        output_w[0] = -actor_w[0] * 0.5
        output_b[0] = 0.0

        # LONG: the actor output as-is.
        output_w[1] = actor_w[0]
        output_b[1] = actor_b[0]

        # SHORT: the negated actor output.
        output_w[2] = -actor_w[0]
        output_b[2] = -actor_b[0]

        layers.append({
            'weights': output_w.tolist(),
            'biases': output_b.tolist(),
            'activation': 'linear',
        })

    return {
        'type': 'PureNN',
        'network': {
            'layers': layers
        }
    }

def create_regime_models_json(models_dir: Path) -> dict:
    """Build the ``{'regimeModels': {...}}`` payload from per-regime checkpoints.

    For each known regime, looks for ``model_<regime>_v6.pt`` inside
    *models_dir* and converts it to the orchestrator format. Regimes whose file
    is missing produce a warning on stdout and are left out of the result.

    Args:
        models_dir: Directory containing the regime checkpoint files.

    Returns:
        ``{'regimeModels': {regime_name: orchestrator_network, ...}}``
    """
    converted = {}

    for name in ('bullish', 'bearish', 'ranging', 'volatile', 'scalper'):
        checkpoint_file = models_dir / f'model_{name}_v6.pt'

        if not checkpoint_file.exists():
            print(f"  Warning: {checkpoint_file} not found")
            continue

        print(f"  Loading {name}...")
        loaded = load_pytorch_model(str(checkpoint_file))
        network = convert_to_orchestrator_format(
            loaded['state_dict'],
            loaded['input_dim'],
        )
        converted[name] = network
        print(f"    Layers: {len(network['network']['layers'])}")

    return {'regimeModels': converted}

def save_to_database(pair: str, weights_json: dict):
    """Upsert per-regime and combined model weights into ``ml_models``.

    One row is written per regime, keyed ``<PAIR>_<regime>``, where ``/`` in
    the pair is replaced by ``_``. A ``<PAIR>_combined`` row is also written,
    holding the whole *weights_json* payload. All rows go into a single
    transaction: either every row is written or none is.

    Args:
        pair: Trading pair such as ``'BTC/EUR'``.
        weights_json: Mapping with an optional ``'regimeModels'`` dict of
            regime name -> orchestrator network.

    Raises:
        Any error from ``psycopg2.connect`` or ``execute``. On an error after
        connecting, the transaction is rolled back and the connection closed
        before the error propagates.
    """
    # On conflict only weights, training_samples and last_trained_at are
    # refreshed; an existing ``accuracy`` value is left untouched.
    # NOTE(review): confirm that preserving the old accuracy is intended.
    upsert_sql = """
        INSERT INTO ml_models (pair, weights, training_samples, accuracy, last_trained_at)
        VALUES (%s, %s, 500, 0.5, NOW())
        ON CONFLICT (pair) DO UPDATE SET
            weights = EXCLUDED.weights,
            training_samples = 500,
            last_trained_at = NOW()
    """

    # Convert pair format: BTC/EUR -> BTC_EUR
    pair_key = pair.replace('/', '_')

    conn = psycopg2.connect(**DB_CONFIG)
    try:
        # ``with conn`` commits on success and rolls back on any exception;
        # it does not close the connection, hence the outer ``finally``.
        with conn:
            with conn.cursor() as cur:
                for regime, model in weights_json.get('regimeModels', {}).items():
                    model_key = f"{pair_key}_{regime}"
                    cur.execute(upsert_sql, (model_key, json.dumps(model)))
                    print(f"  Saved {model_key}")

                combined_key = f"{pair_key}_combined"
                cur.execute(upsert_sql, (combined_key, json.dumps(weights_json)))
                print(f"  Saved {combined_key}")
    finally:
        conn.close()

def main():
    """Convert the regime checkpoints, dump them to JSON, then upsert to the DB.

    The JSON dump in MODEL_DIR is always written, for inspection. The database
    step is skipped when no regime model was converted, so that an empty
    payload never overwrites the stored ``<PAIR>_combined`` row. Database
    failures are reported and do not abort the script; the JSON file remains
    the fallback.
    """
    print("=" * 60)
    print("🔄 Converting PyTorch models to orchestrator format")
    print("=" * 60)

    print(f"\n📂 Model directory: {MODEL_DIR}")
    print(f"📊 Target pair: {PAIR}")

    # Convert models
    print("\n⚡ Converting models...")
    regime_models = create_regime_models_json(MODEL_DIR)

    print(f"\n✅ Converted {len(regime_models['regimeModels'])} regime models")

    # Save to file for inspection. The name is derived from PAIR
    # (BTC/EUR -> btc_eur_models.json).
    output_file = MODEL_DIR / f"{PAIR.replace('/', '_').lower()}_models.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(regime_models, f, indent=2)
    print(f"\n💾 Saved to: {output_file}")

    # Save to database
    if not regime_models['regimeModels']:
        # Without this guard the "<PAIR>_combined" row would be overwritten
        # with an empty regimeModels payload.
        print("\n⚠️  No regime models converted; skipping database save.")
    else:
        print("\n📤 Saving to database...")
        try:
            save_to_database(PAIR, regime_models)
            print("\n✅ Models saved to database!")
        except Exception as e:
            # Top-level boundary: report and fall back to the JSON file.
            print(f"\n❌ Database error: {e}")
            print("   Models saved to JSON file only.")

    print("\n" + "=" * 60)
    print("🎉 Conversion complete!")
    print("=" * 60)

if __name__ == '__main__':
    # Run the conversion only when executed as a script, not when imported.
    main()
