import os
import pandas as pd
import numpy as np
import pywt
import scipy.signal as signal
from collections import deque, Counter
from datetime import datetime
import time
from types import SimpleNamespace

base_dir = os.path.dirname(os.path.abspath(__file__))  # path to the app folder
data_file = os.path.join(base_dir, 'data_puls')         # directory holding the per-person CSV files

# Create the output folder if it does not exist
os.makedirs(data_file, exist_ok=True)

# Rolling windows for the diagnostics: recent beat features (10 s / 5 min / 1 h) and Level 2 labels
last_10s = deque()
last_5min = deque()
last_1h = deque()
last_level2 = deque()

DIAG1_INTERVAL = 10       # Level 1 window: 10 seconds
DIAG2_INTERVAL = 5*60     # Level 2 window: 5 minutes
DIAG3_INTERVAL = 60*60    # Level 3 window: 1 hour

# Level 1 diagnostics: rule-based check on a single beat's features
def diagnose_level1(f):
    # Low beat-to-beat variability short-circuits the remaining rules
    if f.HRV < 0.03:
        return "Me'yorda"                    # normal
    if f.RR < 0.6:                           # RR < 0.6 s -> heart rate above 100 bpm
        return "Taxikardiya (yurak tez)"     # tachycardia (fast heart)
    if f.RR > 1.0:                           # RR > 1.0 s -> heart rate below 60 bpm
        return "Bradikardiya (yurak sekin)"  # bradycardia (slow heart)
    if abs(f.ST) > 0.2:                      # large ST deviation
        return "Ishemiya ehtimoli"           # possible ischemia
    if f.Entropy > 1.0:                      # highly irregular rhythm
        return "Aritmiya ehtimoli"           # possible arrhythmia
    return "Me'yorda"

# Level 2 diagnostics: aggregate Level 1 results over the 5-minute window
def diagnose_level2_from_window(window):
    n = len(window)
    if n == 0: return "Me'yorda"
    cnt_ar = cnt_taxi = cnt_bradi = cnt_isch = 0
    hrv_mean = ent_mean = 0
    for f in window:
        l1 = diagnose_level1(f)
        if "Aritmiya" in l1: cnt_ar += 1
        if "Taxikardiya" in l1: cnt_taxi += 1
        if "Bradikardiya" in l1: cnt_bradi += 1
        if "Ishemiya" in l1: cnt_isch += 1
        hrv_mean += f.HRV
        ent_mean += f.Entropy
    hrv_mean /= max(1, n)
    ent_mean /= max(1, n)
    # Thresholds are fractions of the beats in the window (at least one beat each)
    if cnt_isch >= max(1, n // 10): return "O'tkinchi ishemiya"      # transient ischemia
    if cnt_ar >= max(1, n // 5): return "Takroriy aritmiya"          # recurrent arrhythmia
    if cnt_taxi >= max(1, n // 3): return "Doimiy taxikardiya"       # persistent tachycardia
    if cnt_bradi >= max(1, n // 3): return "Doimiy bradikardiya"     # persistent bradycardia
    if hrv_mean > 0.08 and ent_mean > 0.9: return "Atrial fibrillyatsiya ehtimoli"  # possible atrial fibrillation
    return "Me'yorda"

# Level 3 diagnostics: long-term trend from the Level 2 history and the 1-hour feature window
def diagnose_level3_from_history(level2_labels, window360):
    cnt = Counter(level2_labels)
    if cnt.get("O'tkinchi ishemiya",0) >= max(1,len(level2_labels)//4): return "Surunkali ishemiya"      # chronic ischemia
    if cnt.get("Atrial fibrillyatsiya ehtimoli",0) >= max(1,len(level2_labels)//6): return "Kronik AF"    # chronic AF
    if cnt.get("Doimiy taxikardiya",0) >= max(1,len(level2_labels)//3): return "Surunkali taxikardiya"    # chronic tachycardia
    var_mean = sum(getattr(f,'Variance',0) for f in window360)/max(1,len(window360))
    if var_mean > 0.02: return "Qattiq aritmiya / noto'g'ri ritm"    # severe arrhythmia / irregular rhythm
    return "Me'yorda"

# Feature extraction
def extract_features(file_path):
    df = pd.read_csv(file_path)
    signal_data = df['Heartbeats'].dropna().astype(float).values
    if len(signal_data) < 2:
        return []  # too few samples for wavelet decomposition and peak detection
    fs = 250  # sampling rate in Hz (assumed fixed)

    # Symlet-4 wavelet denoising with a soft universal threshold:
    # sigma*sqrt(2*ln(N)), where sigma is estimated from the finest detail coefficients as MAD/0.6745
    sym_coeffs = pywt.wavedec(signal_data,'sym4',level=4)
    sym_thresh = np.median(np.abs(sym_coeffs[-1]))/0.6745*np.sqrt(2*np.log(len(signal_data)))
    sym_coeffs_thresh = [sym_coeffs[0]] + [pywt.threshold(c, sym_thresh, 'soft') for c in sym_coeffs[1:]]
    sym_filtered = pywt.waverec(sym_coeffs_thresh,'sym4')[:len(signal_data)]

    # R-peak detection: peaks at least 0.6 s apart and above mean + 1 std of the filtered signal
    r_peaks,_ = signal.find_peaks(sym_filtered, distance=fs*0.6, height=np.mean(sym_filtered)+np.std(sym_filtered))
    if len(r_peaks) < 2: return []

    rr_intervals = np.diff(r_peaks)/fs   # RR intervals in seconds
    hrv_sdnn = np.std(rr_intervals)      # SDNN: standard deviation of the RR intervals

    # Coiflet-2 denoising of the same signal, used for P- and T-wave amplitude estimation
    coif_coeffs = pywt.wavedec(signal_data,'coif2',level=4)
    coif_thresh = np.median(np.abs(coif_coeffs[-1]))/0.6745*np.sqrt(2*np.log(len(signal_data)))
    coif_coeffs_thresh = [coif_coeffs[0]] + [pywt.threshold(c, coif_thresh, 'soft') for c in coif_coeffs[1:]]
    coif_filtered = pywt.waverec(coif_coeffs_thresh,'coif2')[:len(signal_data)]

    features = []
    for i in range(1, len(r_peaks)):
        r1, r2 = r_peaks[i-1], r_peaks[i]
        rr = (r2-r1)/fs
        r_peak = r2
        qrs_duration = 0.1  # fixed placeholder, not measured from the signal
        # P wave searched 0.2-0.1 s before the R peak, T wave 0.1-0.4 s after it
        p_region = coif_filtered[max(0, r_peak-int(0.2*fs)): max(0, r_peak-int(0.1*fs))]
        p_amp = np.max(p_region) if len(p_region)>0 else 0
        t_region = coif_filtered[r_peak+int(0.1*fs): min(len(coif_filtered), r_peak+int(0.4*fs))]
        t_amp = np.max(t_region) if len(t_region)>0 else 0
        # ST, PR, QT, Entropy and Variance are fixed placeholders; HRV (SDNN) is shared across all beats
        features.append(SimpleNamespace(
            RR=round(rr,4), QRS=qrs_duration, HRV=hrv_sdnn, ST=0.12,
            PR=0.16, QT=0.4, P_amp=p_amp, T_amp=t_amp, Entropy=0.5, Variance=0.01
        ))
    return features

# Ingest one sample: append it to the raw CSV, update the rolling windows,
# run all three diagnostic levels, and write signal + features + diagnoses to a combined CSV
def process_signal(person_name, heartbeat_value, timestamp=None):
    global last_10s, last_5min, last_1h, last_level2

    if timestamp is None: timestamp = datetime.now()
    now = time.mktime(timestamp.timetuple())  # epoch seconds, used for pruning the rolling windows
    file_path = os.path.join(data_file, f'{person_name}.csv')

    # Create the raw-signal CSV with a header if it does not exist yet
    if not os.path.exists(file_path):
        pd.DataFrame({'Time':[],'Heartbeats':[]}).to_csv(file_path,index=False)

    # Append the new sample to the raw-signal CSV
    df_new = pd.DataFrame({'Time':[timestamp],'Heartbeats':[heartbeat_value]})
    df_new.to_csv(file_path, mode='a', header=False, index=False)

    # Re-extract per-beat features from the full accumulated signal
    features_list = extract_features(file_path)
    if not features_list: return False

    # Push the latest beat's features into the rolling diagnostic windows
    last_10s.append((now, features_list[-1]))
    last_5min.append((now, features_list[-1]))
    last_1h.append((now, features_list[-1]))

    # Drop entries that have fallen out of each window
    last_10s = deque([(t,f) for t,f in last_10s if now-t <= DIAG1_INTERVAL])
    last_5min = deque([(t,f) for t,f in last_5min if now-t <= DIAG2_INTERVAL])
    last_1h = deque([(t,f) for t,f in last_1h if now-t <= DIAG3_INTERVAL])

    # Run the three diagnostic levels
    diag1 = diagnose_level1(last_10s[-1][1]) if last_10s else "Me'yorda"
    diag2 = diagnose_level2_from_window([f for t,f in last_5min])
    last_level2.append((now, diag2))
    diag3 = diagnose_level3_from_history([d for t,d in last_level2 if now-t <= DIAG3_INTERVAL],
                                         [f for t,f in last_1h])

    # Write signal + features + diagnoses to a single combined file
    combined_path = os.path.join(data_file, f'{person_name}_full.csv')

    # Build one row per detected beat; note that each call appends rows for all beats
    # detected so far, each stamped with the current timestamp and sample value
    rows = []
    for f in features_list:
        row = {
            'Time': timestamp,
            'Heartbeats': heartbeat_value,
            'RR': f.RR, 'QRS': f.QRS, 'HRV': f.HRV, 'ST': f.ST,
            'PR': f.PR, 'QT': f.QT, 'P_amp': f.P_amp, 'T_amp': f.T_amp,
            'Entropy': f.Entropy, 'Variance': f.Variance,
            'Level1': diag1, 'Level2': diag2, 'Level3': diag3
        }
        rows.append(row)

    df_combined = pd.DataFrame(rows)
    if not os.path.exists(combined_path):
        df_combined.to_csv(combined_path,index=False)
    else:
        # Append without re-reading the existing file; the header is only written on creation
        df_combined.to_csv(combined_path, mode='a', header=False, index=False)

    return True
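
# Minimal usage sketch (assumption: the module is run directly; the synthetic pulse
# waveform and the person name 'demo' below are illustrative stand-ins, not part of
# the original pipeline).
if __name__ == "__main__":
    demo_fs = 250
    t = np.arange(0, 10, 1.0 / demo_fs)
    # Narrow positive pulses roughly every 0.83 s (~72 bpm) mimicking a pulse sensor stream
    synthetic = np.maximum(0.0, np.sin(2 * np.pi * 1.2 * t)) ** 20

    demo_path = os.path.join(data_file, 'demo.csv')
    pd.DataFrame({'Time': [datetime.now()] * len(synthetic),
                  'Heartbeats': synthetic}).to_csv(demo_path, index=False)

    beats = extract_features(demo_path)
    print(f"{len(beats)} beats detected;",
          "first beat ->", diagnose_level1(beats[0]) if beats else "n/a")

    # Feeding one more sample through the full pipeline also writes demo_full.csv
    print("process_signal returned:", process_signal('demo', float(synthetic[-1])))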
