#!/usr/bin/env python3
"""
Inference Server - Provides trading signals from V9 model
==========================================================
HTTP API for the trading engine to get entry/exit signals.
"""

import json
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from pathlib import Path
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs
import threading
import time
import requests
import joblib

# Configuration
# Directory holding model artifacts; load_model() reads the 'latest' entry inside it.
MODEL_DIR = Path('/var/www/html/pippo.cuttalo.com/models')
# TCP port the HTTP server binds to in main().
PORT = 3071
# Base URL of the Binance REST API; fetch_recent_prices() appends '/klines'.
BINANCE_API = 'https://api.binance.com/api/v3'

# Global state
# Model object and feature-name list; both stay None until load_model() succeeds.
model = None
feature_names = None
# List of candle dicts; replaced wholesale by update_price_cache().
price_cache = []
# Held while price_cache is swapped by the updater and while it is copied/read by handlers.
cache_lock = threading.Lock()


def load_model():
    """Load the trained model and its feature list from ``MODEL_DIR / 'latest'``.

    Reads ``model.joblib`` and ``features.json`` from the resolved ``latest``
    path and publishes them through the module-level ``model`` and
    ``feature_names`` globals.

    Returns:
        bool: True when both artifacts were loaded. False when the ``latest``
        path does not exist or either artifact cannot be read; in that case
        the globals are left untouched.
    """
    global model, feature_names

    latest_path = MODEL_DIR / 'latest'
    if not latest_path.exists():
        print("No model found!")
        return False

    model_path = latest_path.resolve()
    try:
        # Load into locals first so a failure half-way cannot leave the
        # globals in a mixed state (model set, feature list missing or stale).
        # NOTE: joblib.load unpickles arbitrary objects; only point MODEL_DIR
        # at trusted artifacts.
        loaded_model = joblib.load(model_path / 'model.joblib')
        with open(model_path / 'features.json', encoding='utf-8') as f:
            loaded_features = json.load(f)
    except Exception as e:
        # Startup boundary: report the problem and honour the boolean contract
        # instead of letting the traceback escape to the caller.
        print(f"Error loading model from {model_path}: {e}")
        return False

    model, feature_names = loaded_model, loaded_features

    print(f"Loaded model from {model_path}")
    return True


def fetch_recent_prices(symbol='BTCEUR', limit=500):
    """Fetch recent 1-minute candles for *symbol* from the Binance klines endpoint.

    Args:
        symbol: Trading pair symbol sent as the ``symbol`` query parameter.
        limit: Number of candles requested via the ``limit`` query parameter.

    Returns:
        list[dict]: One dict per candle with keys ``timestamp``, ``open``,
        ``high``, ``low``, ``close`` and ``volume``, in the order returned by
        the API. An empty list is returned when the request fails, the server
        answers with a non-success status, or the payload cannot be parsed.
    """
    try:
        resp = requests.get(f'{BINANCE_API}/klines', params={
            'symbol': symbol,
            'interval': '1m',
            'limit': limit
        }, timeout=10)

        if not resp.ok:
            # Surface HTTP-level failures (rate limits, bad symbol, outages)
            # instead of silently returning an empty list.
            print(f"Error fetching prices: HTTP {resp.status_code}")
            return []

        # Each kline is a positional array; indices 0-5 are read here as
        # open time (ms), open, high, low, close, volume.
        # NOTE(review): fromtimestamp() without tz yields naive *local* time,
        # which feeds the hour/day features - confirm this matches training.
        return [
            {
                'timestamp': datetime.fromtimestamp(k[0] / 1000),
                'open': float(k[1]),
                'high': float(k[2]),
                'low': float(k[3]),
                'close': float(k[4]),
                'volume': float(k[5]),
            }
            for k in resp.json()
        ]
    except Exception as e:
        print(f"Error fetching prices: {e}")
    return []


def update_price_cache():
    """Refresh the shared candle cache in an endless loop, once per minute.

    Each cycle calls ``fetch_recent_prices()``; when it returns a non-empty
    list, that list replaces ``price_cache`` while ``cache_lock`` is held and
    a status line is printed. An empty result leaves the previous cache in
    place. Any exception raised during a cycle is printed and the loop
    continues. This function never returns.
    """
    global price_cache

    refresh_seconds = 60  # Update every minute

    while True:
        try:
            candles = fetch_recent_prices()
            if candles:
                with cache_lock:
                    price_cache = candles
                newest = candles[-1]['timestamp']
                print(f"Cache updated: {len(candles)} candles, latest: {newest}")
        except Exception as e:
            print(f"Error updating cache: {e}")

        time.sleep(refresh_seconds)


def create_features(df):
    """Build the per-candle feature frame used for prediction.

    Args:
        df: DataFrame with ``open``, ``high``, ``low``, ``close``, ``volume``
            columns and a datetime-like ``timestamp`` column.

    Returns:
        pandas.DataFrame: Indexed like *df*, one row per candle. Rows near the
        start contain NaN wherever a rolling window or shift is not yet
        filled.
    """
    close, high, low = df['close'], df['high'], df['low']
    volume, open_ = df['volume'], df['open']
    stamps = df['timestamp']

    out = pd.DataFrame(index=df.index)

    # --- Returns, volatility and momentum over several lookbacks ---
    step_return = close.pct_change()
    for window in (5, 10, 20, 50, 100, 200):
        out[f'return_{window}'] = close.pct_change(window)
        out[f'volatility_{window}'] = step_return.rolling(window).std()
        out[f'momentum_{window}'] = close / close.shift(window) - 1

    # --- Simple moving averages: distance from price and crossover flags ---
    sma = {window: close.rolling(window).mean() for window in (10, 20, 50, 100, 200)}
    for window, average in sma.items():
        out[f'price_vs_sma_{window}'] = close / average - 1

    for fast, slow in ((10, 50), (20, 100), (50, 200)):
        out[f'sma_{fast}_{slow}_cross'] = (sma[fast] > sma[slow]).astype(int)

    # --- RSI (14-period simple rolling averages of gains and losses) ---
    change = close.diff()
    avg_gain = change.where(change > 0, 0).rolling(window=14).mean()
    avg_loss = (-change.where(change < 0, 0)).rolling(window=14).mean()
    strength = avg_gain / avg_loss
    rsi = 100 - (100 / (1 + strength))
    out['rsi'] = rsi
    out['rsi_oversold'] = (rsi < 30).astype(int)
    out['rsi_overbought'] = (rsi > 70).astype(int)

    # --- MACD (12/26 EMAs, 9-period signal) and crossover flags ---
    ema_fast = close.ewm(span=12, adjust=False).mean()
    ema_slow = close.ewm(span=26, adjust=False).mean()
    macd_line = ema_fast - ema_slow
    signal_line = macd_line.ewm(span=9, adjust=False).mean()
    prev_macd = macd_line.shift(1)
    prev_signal = signal_line.shift(1)
    out['macd'] = macd_line
    out['macd_signal'] = signal_line
    out['macd_hist'] = macd_line - signal_line
    out['macd_cross_up'] = ((macd_line > signal_line) & (prev_macd <= prev_signal)).astype(int)
    out['macd_cross_down'] = ((macd_line < signal_line) & (prev_macd >= prev_signal)).astype(int)

    # --- Bollinger bands (20-period mean +/- 2 standard deviations) ---
    band_mid = sma[20]
    band_dev = close.rolling(20).std()
    band_upper = band_mid + (band_dev * 2)
    band_lower = band_mid - (band_dev * 2)
    band_span = band_upper - band_lower
    out['bb_position'] = (close - band_lower) / (band_span + 0.0001)
    out['bb_width'] = band_span / band_mid
    out['price_vs_bb_upper'] = close / band_upper - 1
    out['price_vs_bb_lower'] = close / band_lower - 1

    # --- True range and its 14-period average ---
    prev_close = close.shift(1)
    range_candidates = [high - low, (high - prev_close).abs(), (low - prev_close).abs()]
    true_range = pd.concat(range_candidates, axis=1).max(axis=1)
    avg_true_range = true_range.rolling(14).mean()
    out['atr'] = avg_true_range
    out['atr_pct'] = avg_true_range / close

    # --- Volume relative to its 20-period mean ---
    out['volume_ratio'] = volume / volume.rolling(20).mean()
    out['volume_change'] = volume.pct_change()

    # --- Candle shape; the 0.0001 offset avoids division by zero on flat candles ---
    candle_range = high - low + 0.0001
    body_edges = pd.concat([close, open_], axis=1)
    out['higher_high'] = (high > high.shift(1)).astype(int)
    out['lower_low'] = (low < low.shift(1)).astype(int)
    out['body_size'] = (close - open_).abs() / candle_range
    out['upper_wick'] = (high - body_edges.max(axis=1)) / candle_range
    out['lower_wick'] = (body_edges.min(axis=1) - low) / candle_range

    # --- Directional movement, averaged into the 'adx' column ---
    up_move = high.diff()
    up_move = up_move.mask(up_move < 0, 0)
    down_move = low.diff()
    down_move = down_move.mask(down_move > 0, 0)
    range_total = true_range.rolling(14).sum()
    plus_di = 100 * (up_move.rolling(14).sum() / range_total)
    minus_di = 100 * (abs(down_move.rolling(14).sum()) / range_total)
    directional_index = 100 * abs(plus_di - minus_di) / (plus_di + minus_di + 0.0001)
    out['adx'] = directional_index.rolling(14).mean()

    # --- Calendar features taken from the candle timestamp ---
    out['hour'] = stamps.dt.hour
    out['day_of_week'] = stamps.dt.dayofweek
    out['is_weekend'] = (out['day_of_week'] >= 5).astype(int)

    return out


def get_signal(confidence_threshold=0.6):
    """Produce a LONG / SHORT / HOLD signal from the latest cached candle.

    Args:
        confidence_threshold: Score a direction must exceed (while also
            beating the opposite direction) to be emitted instead of HOLD.

    Returns:
        dict: On success, keys ``signal``, ``confidence``, ``price`` and
        ``probabilities`` (``long`` / ``short`` / ``neutral``). When the model
        or data is unavailable, fewer than 250 candles are cached, the feature
        row contains NaN, or any exception occurs, a HOLD dict with
        ``confidence`` 0 and an ``error`` message is returned instead.
    """
    def _hold(reason):
        # Uniform "no trade" answer carrying the reason it was produced.
        return {'signal': 'HOLD', 'confidence': 0, 'error': reason}

    if model is None or not price_cache:
        return _hold('Model or data not ready')

    # Copy the cache into a DataFrame while the updater thread is locked out.
    with cache_lock:
        candles = pd.DataFrame(price_cache)

    if len(candles) < 250:
        return _hold('Not enough data')

    try:
        feature_frame = create_features(candles)

        # Only the most recent row is scored, kept 2-D for the model.
        X = feature_frame[feature_names].iloc[-1:].values

        if np.isnan(X).any():
            return _hold('Invalid features')

        # Per the original note the row is ordered [short, neutral, long];
        # presumably model.predict returns per-class scores - verify against
        # the training code.
        proba = model.predict(X)[0]
        short_prob, neutral_prob, long_prob = proba[0], proba[1], proba[2]

        current_price = candles['close'].iloc[-1]

        if long_prob > confidence_threshold and long_prob > short_prob:
            label, strength = 'LONG', long_prob
        elif short_prob > confidence_threshold and short_prob > long_prob:
            label, strength = 'SHORT', short_prob
        else:
            label, strength = 'HOLD', max(proba)

        return {
            'signal': label,
            'confidence': float(strength),
            'price': current_price,
            'probabilities': {
                'long': float(long_prob),
                'short': float(short_prob),
                'neutral': float(neutral_prob)
            }
        }

    except Exception as e:
        return _hold(str(e))


class RequestHandler(BaseHTTPRequestHandler):
    """HTTP request handler for the inference server.

    Endpoints:
        GET  /entry-signal?confidence_threshold=<float>  -> result of get_signal()
        GET  /health                                      -> server status summary
        GET  /price                                       -> latest cached close price
        POST /check-exit                                  -> always {'action': 'HOLD'}

    Unknown paths get an empty 404. Malformed client input (non-numeric
    threshold, invalid Content-Length, non-JSON body) gets a 400 with a JSON
    ``error`` field instead of an unhandled exception.
    """

    def log_message(self, format, *args):
        # Suppress default logging
        pass

    def _send_json(self, payload, status=200):
        """Serialize *payload* to JSON and send it with the given status code."""
        body = json.dumps(payload).encode()
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_not_found(self):
        """Send an empty 404 response."""
        self.send_response(404)
        self.end_headers()

    def do_GET(self):
        """Dispatch GET requests to the signal, health and price endpoints."""
        parsed = urlparse(self.path)
        path = parsed.path
        params = parse_qs(parsed.query)

        if path == '/entry-signal':
            raw_threshold = params.get('confidence_threshold', [0.6])[0]
            try:
                conf = float(raw_threshold)
            except ValueError:
                # A bad query value is a client error, not a server crash.
                self._send_json(
                    {'error': f'Invalid confidence_threshold: {raw_threshold!r}'},
                    status=400,
                )
                return
            self._send_json(get_signal(conf))

        elif path == '/health':
            self._send_json({
                'status': 'ok',
                'model_loaded': model is not None,
                'cache_size': len(price_cache),
                'timestamp': datetime.now().isoformat()
            })

        elif path == '/price':
            # Read both fields under the lock so they come from the same candle.
            with cache_lock:
                if price_cache:
                    result = {
                        'price': price_cache[-1]['close'],
                        'timestamp': price_cache[-1]['timestamp'].isoformat()
                    }
                else:
                    result = {'error': 'No price data'}
            self._send_json(result)

        else:
            self._send_not_found()

    def do_POST(self):
        """Handle POST /check-exit; the body must be valid JSON but is not used yet."""
        path = urlparse(self.path).path

        if path != '/check-exit':
            self._send_not_found()
            return

        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            self._send_json({'error': 'Invalid Content-Length'}, status=400)
            return
        if content_length < 0:
            # read(-1) would block until the client closes the connection.
            self._send_json({'error': 'Invalid Content-Length'}, status=400)
            return

        body = self.rfile.read(content_length)
        try:
            json.loads(body)
        except ValueError:
            # Covers json.JSONDecodeError and undecodable bytes.
            self._send_json({'error': 'Request body must be valid JSON'}, status=400)
            return

        # For now, let stop loss and take profit handle exits
        self._send_json({'action': 'HOLD'})


def main():
    """Start the inference server: load the model, warm the cache, serve HTTP.

    Returns early (after printing a message) when the model cannot be loaded.
    Otherwise starts the price-cache updater as a daemon thread, waits a few
    seconds for the first fetch, and serves requests on ``PORT`` until
    interrupted with Ctrl-C.
    """
    print("=" * 50)
    print("INFERENCE SERVER")
    print("=" * 50)

    # Load model
    if not load_model():
        print("Failed to load model!")
        return

    # Start cache update thread (daemon, so it never blocks process exit)
    cache_thread = threading.Thread(target=update_price_cache, daemon=True)
    cache_thread.start()

    # Wait for initial cache
    print("Waiting for price data...")
    time.sleep(5)

    # Start HTTP server
    server = HTTPServer(('0.0.0.0', PORT), RequestHandler)
    print(f"Server running on port {PORT}")
    print("Endpoints:")
    print("  GET /entry-signal?confidence_threshold=0.6")
    print("  GET /health")
    print("  GET /price")
    print("  POST /check-exit")

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nShutting down...")
    finally:
        # serve_forever() has already returned at this point, so there is no
        # loop left for shutdown() to stop (and shutdown() is meant to be
        # called from a different thread). What remains is releasing the
        # listening socket.
        server.server_close()


# Start the server only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
