#!/usr/bin/env python3
"""
Multi-Trader Paper Trading Engine
=================================

Manages multiple traders side by side for paper trading.
Each trader has its own independent configuration and state.
"""

import asyncio
import json
import os
import signal
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

import psycopg2
import requests
from psycopg2.extras import RealDictCursor

# Configuration
# Every connection setting can be overridden through an environment variable so
# credentials do not have to live in source control; the literals below are
# only fallbacks matching the previous hard-coded values.
DB_CONFIG = {
    'host': os.environ.get('PIPPO_DB_HOST', 'localhost'),
    'port': int(os.environ.get('PIPPO_DB_PORT', '5432')),
    'dbname': os.environ.get('PIPPO_DB_NAME', 'pippo'),
    'user': os.environ.get('PIPPO_DB_USER', 'pippo'),
    'password': os.environ.get('PIPPO_DB_PASSWORD', 'pippo_trading_2026'),
}

BINANCE_API = os.environ.get('PIPPO_BINANCE_API', 'https://api.binance.com/api/v3')
INFERENCE_URL = os.environ.get('PIPPO_INFERENCE_URL', 'http://localhost:3071')
CHECK_INTERVAL = 60  # seconds between trading iterations

# Global state
running = True  # set to False by signal_handler to request shutdown
traders: Dict[int, Dict[str, Any]] = {}  # trader id -> in-memory trader record


def log(msg: str, trader_id: Optional[int] = None):
    """Write a timestamped log line to stdout.

    Args:
        msg: Message text.
        trader_id: When given (including 0), the line is tagged
            ``[Trader <id>]``; otherwise it is tagged ``[Engine]``.

    The whole line is emitted with a single ``write`` call and flushed
    immediately, so output is visible right away when stdout is redirected
    to a file or service journal, and lines from concurrent callers are less
    likely to interleave.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Compare against None so a trader whose id is 0 is still tagged correctly.
    prefix = f"[Trader {trader_id}]" if trader_id is not None else "[Engine]"
    sys.stdout.write(f"[{timestamp}] {prefix} {msg}\n")
    sys.stdout.flush()


def get_db_connection():
    """Open a new PostgreSQL connection using DB_CONFIG.

    Cursors created from the returned connection use ``RealDictCursor``, so
    fetched rows can be indexed by column name. The caller owns the
    connection and must close it.
    """
    return psycopg2.connect(cursor_factory=RealDictCursor, **DB_CONFIG)


def get_current_price(symbol='BTCEUR') -> float:
    """Return the latest Binance ticker price for ``symbol``.

    Args:
        symbol: Binance symbol, e.g. ``'BTCEUR'``.

    Returns:
        The price as a float, or ``0.0`` when the request fails, the server
        answers with a non-success status, or the body cannot be parsed.
        Every failure is logged; callers treat a non-positive value as
        "no price available".
    """
    try:
        # Let requests encode the query string instead of formatting it by hand.
        resp = requests.get(f'{BINANCE_API}/ticker/price', params={'symbol': symbol}, timeout=5)
        if not resp.ok:
            log(f"Error fetching price: HTTP {resp.status_code}")
            return 0.0
        return float(resp.json()['price'])
    except Exception as e:
        # Network errors, bad JSON, a missing 'price' key or a non-numeric value.
        log(f"Error fetching price: {e}")
    return 0.0


def _trader_config_from_row(row) -> Dict[str, Any]:
    """Extract a trader's configuration fields (not runtime state) from a DB row."""
    return {
        'name': row['name'],
        'status': row['status'],
        'initial_capital': float(row['initial_capital']),
        'position_size': float(row['position_size']),
        'stop_loss': float(row['stop_loss']),
        'take_profit': float(row['take_profit']),
        'confidence_threshold': float(row['confidence_threshold']),
        'pair': row['pair'],
        'fee_rate': float(row['fee_rate']),
    }


def _new_trader_from_row(row) -> Dict[str, Any]:
    """Build a complete in-memory trader record (config + persisted state) from a DB row."""
    trader = {'id': row['id']}
    trader.update(_trader_config_from_row(row))
    trader.update({
        'current_capital': float(row['current_capital']),
        # trader_states is LEFT JOINed, so these are NULL for a trader with no saved state.
        'position': row['position'] or 0,
        'position_amount': float(row['position_amount'] or 0),
        'entry_price': float(row['entry_price'] or 0),
        'entry_time': row['entry_time'],
        'last_check': None,
    })
    return trader


def load_active_traders():
    """Synchronize the global ``traders`` dict with the database.

    - Traders with status 'paper' or 'live' that are not yet in memory are
      added, together with their saved position from ``trader_states``.
    - Traders already in memory get their configuration fields refreshed
      (e.g. ``status``, ``stop_loss``). Their runtime state (capital and open
      position), which this engine owns, is left untouched.
    - Traders no longer in an active status are removed.

    A row with NULL or non-numeric config values is logged and skipped, so it
    cannot block the other traders. If the query itself fails, the error is
    logged and ``traders`` is left unchanged.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute("""
                SELECT t.*, ts.position, ts.position_amount, ts.entry_price, ts.entry_time
                FROM traders t
                LEFT JOIN trader_states ts ON t.id = ts.trader_id
                WHERE t.status IN ('paper', 'live')
            """)
            rows = cur.fetchall()
    except Exception as e:
        log(f"Error loading traders: {e}")
        return
    finally:
        # Always release the connection, even when the query fails.
        if conn is not None:
            conn.close()

    active_ids = set()
    for row in rows:
        trader_id = row['id']
        active_ids.add(trader_id)
        try:
            if trader_id in traders:
                # Pick up config edits made while the engine is running.
                traders[trader_id].update(_trader_config_from_row(row))
            else:
                traders[trader_id] = _new_trader_from_row(row)
                log(f"Loaded trader: {row['name']}", trader_id)
        except (TypeError, ValueError) as e:
            log(f"Invalid trader row, skipping: {e}", trader_id)

    # Remove stopped traders
    for tid in list(traders):
        if tid not in active_ids:
            log("Trader stopped, removing", tid)
            del traders[tid]


def save_trader_state(trader_id: int, state: Dict[str, Any]):
    """Persist a trader's position state and current capital.

    Both statements run in one transaction and are committed together:
      - upsert into ``trader_states`` (position, amount, entry price and time);
      - update of ``traders.current_capital``.

    Persistence is best-effort. Errors are logged, not raised, so after a
    failure the in-memory ``state`` may be ahead of the database. The
    connection is always closed; an uncommitted transaction is discarded.

    Args:
        trader_id: Database id of the trader.
        state: In-memory trader record. Reads ``position``,
            ``position_amount``, ``entry_price``, ``entry_time`` and
            ``current_capital``.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO trader_states (trader_id, position, position_amount, entry_price, entry_time, updated_at)
                VALUES (%s, %s, %s, %s, %s, NOW())
                ON CONFLICT (trader_id) DO UPDATE SET
                    position = EXCLUDED.position,
                    position_amount = EXCLUDED.position_amount,
                    entry_price = EXCLUDED.entry_price,
                    entry_time = EXCLUDED.entry_time,
                    updated_at = NOW()
            """, (
                trader_id,
                state['position'],
                state['position_amount'],
                state['entry_price'],
                state['entry_time']
            ))

            # Update current capital
            cur.execute("""
                UPDATE traders SET current_capital = %s, updated_at = NOW() WHERE id = %s
            """, (state['current_capital'], trader_id))

        conn.commit()
    except Exception as e:
        log(f"Error saving state: {e}", trader_id)
    finally:
        if conn is not None:
            conn.close()


def save_trade(trader_id: int, trade: Dict[str, Any], mode: str):
    """Insert one trade row (an OPEN or a CLOSE event) into ``trades``.

    Missing keys in ``trade`` are stored as NULL, except ``pair``, which
    defaults to 'BTC/EUR', and the fees, which default to 0. Persistence is
    best-effort: errors are logged, not raised, and the connection is always
    closed.

    Args:
        trader_id: Database id of the trader.
        trade: Trade fields keyed by column name.
        mode: Value for the ``mode`` column (callers pass the trader status).
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO trades (
                    trader_id, mode, pair, direction, action,
                    entry_price, exit_price, amount,
                    fee_entry, fee_exit, gross_pnl, net_pnl, pnl_percent,
                    exit_reason, entry_time, exit_time, duration_minutes,
                    capital_before, capital_after
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
                )
            """, (
                trader_id,
                mode,
                trade.get('pair', 'BTC/EUR'),
                trade.get('direction'),
                trade.get('action'),
                trade.get('entry_price'),
                trade.get('exit_price'),
                trade.get('amount'),
                trade.get('fee_entry', 0),
                trade.get('fee_exit', 0),
                trade.get('gross_pnl'),
                trade.get('net_pnl'),
                trade.get('pnl_percent'),
                trade.get('exit_reason'),
                trade.get('entry_time'),
                trade.get('exit_time'),
                trade.get('duration_minutes'),
                trade.get('capital_before'),
                trade.get('capital_after'),
            ))

        conn.commit()
    except Exception as e:
        log(f"Error saving trade: {e}", trader_id)
    finally:
        if conn is not None:
            conn.close()


def get_entry_signal(trader: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Ask the inference server whether ``trader`` should open a position.

    Sends ``trader_id`` and ``confidence_threshold`` as query parameters to
    ``/entry-signal``.

    Returns:
        The decoded JSON object, or None in any of these cases, each logged:
          - the request fails;
          - the server answers with a non-success status;
          - the body is not valid JSON;
          - the JSON is not an object.
    """
    try:
        resp = requests.get(f'{INFERENCE_URL}/entry-signal', params={
            'trader_id': trader['id'],
            'confidence_threshold': trader['confidence_threshold']
        }, timeout=10)
        if not resp.ok:
            log(f"Error getting signal: HTTP {resp.status_code}", trader['id'])
            return None
        payload = resp.json()
    except Exception as e:
        log(f"Error getting signal: {e}", trader['id'])
        return None

    # Callers use .get() on the result, so anything other than a dict is rejected here.
    if not isinstance(payload, dict):
        log(f"Error getting signal: unexpected payload {payload!r}", trader['id'])
        return None
    return payload


def check_exit(trader: Dict[str, Any], current_price: float) -> Optional[Dict[str, Any]]:
    """Decide whether ``trader``'s open position should be closed.

    Checks are applied in order:
      1. stop loss: P&L fraction <= -stop_loss;
      2. take profit: P&L fraction >= take_profit;
      3. the inference server's ``/check-exit`` endpoint.

    The P&L fraction is measured at ``current_price`` against
    ``entry_price``. ``position`` 1 is treated as long; any other non-zero
    value is treated as short.

    Returns:
        ``{'action': 'EXIT', 'reason': 'stop_loss' | 'take_profit' | 'model',
        'pnl_pct': <fraction>}``, or None when there is no position, the entry
        price is invalid, or no exit condition triggers.
    """
    direction = trader['position']
    if direction == 0:
        return None

    entry_price = trader['entry_price']
    if entry_price <= 0:
        # Corrupt state (e.g. a position saved without an entry price): the
        # P&L fraction would divide by zero, so report instead of crashing.
        log(f"Cannot evaluate exit: invalid entry price {entry_price}", trader['id'])
        return None

    # Calculate P&L
    if direction == 1:  # Long
        pnl_pct = (current_price - entry_price) / entry_price
    else:  # Short
        pnl_pct = (entry_price - current_price) / entry_price

    # Check stop loss
    if pnl_pct <= -trader['stop_loss']:
        return {'action': 'EXIT', 'reason': 'stop_loss', 'pnl_pct': pnl_pct}

    # Check take profit
    if pnl_pct >= trader['take_profit']:
        return {'action': 'EXIT', 'reason': 'take_profit', 'pnl_pct': pnl_pct}

    # Check with inference server for model exit
    try:
        resp = requests.post(f'{INFERENCE_URL}/check-exit', json={
            'trader_id': trader['id'],
            'current_price': current_price,
            'entry_price': entry_price,
            'direction': direction
        }, timeout=10)
        if not resp.ok:
            log(f"Error checking model exit: HTTP {resp.status_code}", trader['id'])
            return None
        result = resp.json()
        if isinstance(result, dict) and result.get('action') == 'EXIT':
            return {'action': 'EXIT', 'reason': 'model', 'pnl_pct': pnl_pct}
    except Exception as e:
        # The model exit is advisory; the stop-loss/take-profit checks above
        # still run on every call, so log the failure and keep the position.
        log(f"Error checking model exit: {e}", trader['id'])

    return None


def execute_entry(trader: Dict[str, Any], price: float, direction: int, confidence: float):
    """Open a simulated position for ``trader`` and persist it.

    Commits ``current_capital * position_size`` to the trade and deducts it
    from ``current_capital``. Half of ``fee_rate`` is charged on entry, and a
    fixed 0.03% slippage is applied against the trader (longs fill higher,
    shorts fill lower). The OPEN trade row and the updated state are saved
    through ``save_trade`` / ``save_trader_state``.

    Args:
        trader: In-memory trader record; mutated in place.
        price: Current market price.
        direction: 1 opens a long; any other value opens a short (callers pass -1).
        confidence: Signal confidence, used only in the log line.
    """
    trade_value = trader['current_capital'] * trader['position_size']
    if trade_value <= 0:
        # A zero/negative-size position cannot be closed meaningfully (its
        # P&L percentage divides by the position's cost), so do not open it.
        log(f"Skipping entry: non-positive trade value €{trade_value:,.2f}", trader['id'])
        return

    fee = trade_value * trader['fee_rate'] / 2  # Half fee on entry

    slippage = price * 0.0003
    exec_price = price + slippage if direction == 1 else price - slippage

    amount = (trade_value - fee) / exec_price

    capital_before = trader['current_capital']
    trader['current_capital'] -= trade_value
    trader['position'] = direction
    trader['position_amount'] = amount
    trader['entry_price'] = exec_price
    trader['entry_time'] = datetime.now()

    dir_str = 'LONG' if direction == 1 else 'SHORT'

    # Save trade
    save_trade(trader['id'], {
        'pair': trader['pair'],
        'direction': dir_str,
        'action': 'OPEN',
        'entry_price': exec_price,
        'amount': amount,
        'fee_entry': fee,
        'entry_time': trader['entry_time'],
        'capital_before': capital_before,
        'capital_after': trader['current_capital'],
    }, trader['status'])

    save_trader_state(trader['id'], trader)

    log(f"{dir_str} {amount:.6f} BTC @ €{exec_price:,.2f} | Conf: {confidence:.1%}", trader['id'])


def execute_exit(trader: Dict[str, Any], price: float, reason: str):
    """Close ``trader``'s open simulated position and persist the result.

    Half of ``fee_rate`` is charged on the exit notional, and a fixed 0.03%
    slippage is applied against the trader. The capital committed at entry
    (``position_amount * entry_price``) plus the net P&L is credited back to
    ``current_capital``. The CLOSE trade row and the reset (flat) state are
    then saved.

    Every value that might raise is computed before ``trader`` is mutated. A
    failure therefore cannot leave capital credited while the position stays
    open; that state would be credited again on the next iteration.

    Args:
        trader: In-memory trader record with an open position; mutated in place.
        price: Current market price.
        reason: Exit reason stored in the trade row (e.g. 'stop_loss').
    """
    direction = trader['position']
    amount = trader['position_amount']
    entry_price = trader['entry_price']
    entry_time = trader['entry_time']
    fee = amount * price * trader['fee_rate'] / 2  # Half fee on exit

    slippage = price * 0.0003
    if direction == 1:  # Close long
        exec_price = price - slippage
        gross_pnl = amount * (exec_price - entry_price)
    else:  # Close short
        exec_price = price + slippage
        gross_pnl = amount * (entry_price - exec_price)

    net_pnl = gross_pnl - fee
    cost_basis = amount * entry_price
    # Guard against a corrupt zero-size position instead of dividing by zero.
    pnl_pct = net_pnl / cost_basis if cost_basis else 0.0

    # entry_time may come back from the database timezone-aware (presumably a
    # timestamptz column — not visible here). Taking "now" in the same tzinfo
    # (None -> naive local time, as before) keeps the subtraction valid.
    exit_time = datetime.now(entry_time.tzinfo if entry_time else None)
    duration = None
    if entry_time:
        duration = int((exit_time - entry_time).total_seconds() / 60)

    capital_before = trader['current_capital']
    trader['current_capital'] = capital_before + cost_basis + net_pnl

    # Save trade
    save_trade(trader['id'], {
        'pair': trader['pair'],
        'direction': 'LONG' if direction == 1 else 'SHORT',
        'action': 'CLOSE',
        'entry_price': entry_price,
        'exit_price': exec_price,
        'amount': amount,
        'fee_exit': fee,
        'gross_pnl': gross_pnl,
        'net_pnl': net_pnl,
        'pnl_percent': pnl_pct * 100,
        'exit_reason': reason,
        'entry_time': entry_time,
        'exit_time': exit_time,
        'duration_minutes': duration,
        'capital_before': capital_before,
        'capital_after': trader['current_capital'],
    }, trader['status'])

    # Reset position
    trader['position'] = 0
    trader['position_amount'] = 0
    trader['entry_price'] = 0
    trader['entry_time'] = None

    save_trader_state(trader['id'], trader)

    sign = '+' if net_pnl > 0 else ''
    log(f"CLOSE ({reason}) @ €{exec_price:,.2f} | P&L: {sign}€{net_pnl:.2f} ({pnl_pct*100:+.2f}%)", trader['id'])


def _process_trader_sync(trader: Dict[str, Any], price: float):
    """Run one blocking trading step for ``trader``; errors are logged, never raised."""
    trader_id = trader['id']

    try:
        # If no position, check for entry
        if trader['position'] == 0:
            # Named entry_signal so the `signal` module is not shadowed.
            entry_signal = get_entry_signal(trader)

            if entry_signal and entry_signal.get('signal') in ('LONG', 'SHORT'):
                direction = 1 if entry_signal['signal'] == 'LONG' else -1
                # Coerce before trading: a None/str confidence would otherwise
                # crash the log line *after* the trade had been saved.
                confidence = float(entry_signal.get('confidence') or 0)
                execute_entry(trader, price, direction, confidence)

        # If in position, check for exit
        else:
            exit_signal = check_exit(trader, price)

            if exit_signal and exit_signal.get('action') == 'EXIT':
                execute_exit(trader, price, exit_signal.get('reason', 'unknown'))

    except Exception as e:
        log(f"Error processing: {e}", trader_id)


async def process_trader(trader: Dict[str, Any], price: float):
    """Process a single trader without blocking the event loop.

    A trading step makes blocking HTTP (requests) and database (psycopg2)
    calls. It is therefore run in the event loop's default thread-pool
    executor, so that ``asyncio.gather`` over several traders actually
    overlaps their I/O instead of running them one after another.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _process_trader_sync, trader, price)


_STATUS_LOG_INTERVAL = 5 * 60  # seconds between periodic status lines


async def _sleep_while_running(seconds: float, step: float = 1.0):
    """Sleep for up to ``seconds``, returning early once ``running`` is False."""
    deadline = time.monotonic() + seconds
    while running:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        await asyncio.sleep(min(step, remaining))


async def main_loop():
    """Main trading loop.

    Each iteration:
      1. reloads the active traders;
      2. fetches one market price;
      3. runs every trader's step against that price;
      4. waits CHECK_INTERVAL seconds.

    A status line is logged on the first completed iteration and then at
    most once every five minutes. Waits are interruptible: the loop ends
    within about a second after ``running`` is cleared (signal_handler), or
    when the task is cancelled.
    """
    log("Starting Multi-Trader Engine")
    log("=" * 50)

    last_status = None  # time.monotonic() of the last status line

    while running:
        try:
            # Reload active traders
            load_active_traders()

            if not traders:
                await _sleep_while_running(10)
                continue

            # NOTE(review): get_current_price() defaults to 'BTCEUR', so every
            # trader is evaluated against the BTC/EUR price regardless of its
            # trader['pair'] — confirm all active traders trade BTC/EUR.
            price = get_current_price()
            if price <= 0:
                log("Could not get price, retrying...")
                await _sleep_while_running(30)
                continue

            # Process all traders
            tasks = [process_trader(t, price) for t in traders.values()]
            await asyncio.gather(*tasks)

            # Log status periodically (monotonic clock: immune to wall-clock jumps)
            now = time.monotonic()
            if last_status is None or now - last_status >= _STATUS_LOG_INTERVAL:
                last_status = now
                active = sum(1 for t in traders.values() if t['position'] != 0)
                total_capital = sum(t['current_capital'] for t in traders.values())
                log(f"Status: {len(traders)} traders, {active} in position, €{total_capital:,.2f} capital")

            await _sleep_while_running(CHECK_INTERVAL)

        except asyncio.CancelledError:
            break
        except Exception as e:
            log(f"Error in main loop: {e}")
            await _sleep_while_running(30)

    log("Engine stopped")


def signal_handler(sig, frame):
    """Handle SIGINT/SIGTERM by requesting a graceful shutdown.

    Only clears the module-level ``running`` flag; ``main_loop`` checks that
    flag with ``while running`` and stops looping. ``sig`` and ``frame`` are
    required by the ``signal.signal`` handler signature and are unused.
    """
    global running
    log("Shutting down...")
    running = False


if __name__ == '__main__':
    # Route both Ctrl+C and service-manager stop requests to the shutdown handler.
    for _shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(_shutdown_signal, signal_handler)

    asyncio.run(main_loop())
