#!/usr/bin/env python3
import os
import re
import math
import argparse
from collections import defaultdict, deque, namedtuple

import numpy as np
import pandas as pd
from Bio.PDB import PDBParser, MMCIFParser, is_aa
from Bio.PDB.vectors import calc_angle  # for convenience

from scipy.spatial import cKDTree
from numpy.linalg import lstsq, eig, inv, svd
from typing import List, Dict, Tuple

from scipy.interpolate import BSpline
from Bio.PDB import DSSP

from pymol import cmd

import subprocess
import tempfile
from PIL import Image

# ------------------------
# Parameters / thresholds
# ------------------------
# DSSP-style hydrogen-bond energy threshold (kcal/mol).
# detect_bridges() takes this as its default cutoff and only accepts
# energies strictly below it.
HBOND_ENERGY_CUTOFF = -0.5

# Geometry fallbacks.
# NOTE(review): the remaining thresholds in this section are not referenced in
# the part of the file reviewed here -- confirm they are still used elsewhere.
ON_DIST_MAX = 3.5   # Å: maximum O...N distance (candidate_hbond_pairs() hard-codes the same 3.5 default)
CA_PAIR_MIN = 3.0   # Å: minimum Cα–Cα distance for pairing sanity
CA_PAIR_MAX = 6.0   # Å: maximum Cα–Cα distance for pairing sanity

# Minimum number of residue pairs for a ladder to be considered robust.
MIN_LADDER_PAIRS = 3

# Orientation filter thresholds, expressed as cosines of the angle between
# strand axes (presumably: >= 0.4 parallel, <= -0.4 antiparallel -- verify
# against the code that consumes them).
PARALLEL_COS_MIN = 0.4
ANTIPARALLEL_COS_MAX = -0.4

# Sheet planarity tolerance in Å (RMSD of Cα atoms to the sheet PCA plane).
PLANARITY_RMSD_MAX = 3.0

# ------------------------
# Small helpers
# ------------------------
def norm(v):
    """Return ``v`` rescaled to unit length.

    A vector whose length is below 1e-12 cannot be normalised reliably, so
    it is returned as an all-zero vector (``v * 0.0``) instead.
    """
    length = np.linalg.norm(v)
    return v * 0.0 if length < 1e-12 else v / length

def dist(a, b):
    """Euclidean distance between coordinate arrays ``a`` and ``b`` as a Python float."""
    separation = a - b
    return float(np.linalg.norm(separation))

def rmsd_to_plane(points):
    """RMS distance of ``points`` from their best-fit (PCA) plane.

    The plane passes through the centroid; its normal is the third right
    singular vector of the centred coordinates. Fewer than three points
    always yield 0.0.
    """
    if len(points) < 3:
        return 0.0
    coords = np.asarray(points)
    centered = coords - coords.mean(axis=0)
    _, _, axes = np.linalg.svd(centered, full_matrices=False)
    plane_normal = axes[2]  # direction of least variance for 3-D input
    offsets = np.abs(np.dot(centered, plane_normal))
    return float(np.sqrt(np.mean(offsets * offsets)))

# ------------------------
# Hydrogen position estimation
# ------------------------

# Default geometry for placing a backbone amide hydrogen. Both values are the
# defaults of place_amide_h_internal(); NH_BOND_LEN is also the default of
# place_amide_h_bisector().
NH_BOND_LEN = 1.01    # Å, typical N-H bond length (1.00 - 1.02 Å)
NH_ANGLE_DEG = 120.0  # approximate H-N-CA angle (deg)

def _safe_norm(v):
    n = np.linalg.norm(v)
    if n < 1e-12:
        return np.zeros_like(v), 0.0
    return v / n, n

def place_amide_h_bisector(N, CA, C_prev, bond_length=NH_BOND_LEN):
    """
    Place the backbone amide hydrogen H on atom N, opposite the bisector of
    the N->C_prev and N->CA bonds.

    Parameters:
      N, CA, C_prev: coordinates (array-like, 3 values). CA or C_prev (not
        both) may be None.
      bond_length: N-H distance in Å.

    Returns:
      H coordinate as a numpy array (3,), at distance ``bond_length`` from N
      (up to the 1e-12 guard used in the final normalisation).

    Raises:
      ValueError if N is None, or if both CA and C_prev are None.

    Notes:
      - With both neighbours present, H lies in the C_prev-N-CA plane on the
        side facing away from both heavy atoms, so the H-N-CA and H-N-C_prev
        angles are equal (about 120° for a typical trigonal amide nitrogen).
      - The sum of two unit vectors never has a negative dot product with
        either of them, so the raw bisector always points *between* the heavy
        atoms; it must therefore be negated to obtain the H direction.
      - If only one neighbour is known, H is placed directly opposite that
        neighbour (a crude fallback).
    """
    if N is None:
        raise ValueError("N coordinate required")
    if CA is None and C_prev is None:
        raise ValueError("At least one of CA or C_prev coordinates must be provided")

    # Unit vectors from N toward each available neighbour
    v_ca, _ = _safe_norm(np.asarray(CA) - np.asarray(N)) if CA is not None else (None, 0.0)
    v_cprev, _ = _safe_norm(np.asarray(C_prev) - np.asarray(N)) if C_prev is not None else (None, 0.0)

    if v_ca is None:
        # only C_prev known: point straight away from it
        v = -v_cprev
    elif v_cprev is None:
        # only CA known: point straight away from it
        v = -v_ca
    else:
        sum_vec = v_ca + v_cprev
        s_norm = np.linalg.norm(sum_vec)
        if s_norm < 1e-6:
            # The two bonds are (nearly) anti-parallel, so the bisector is
            # undefined; keep the historical fallback for this rare case.
            plane_norm = np.cross(v_cprev, v_ca)
            if np.linalg.norm(plane_norm) < 1e-6:
                v = -v_ca
            else:
                v = v_ca - v_cprev  # arbitrary in-plane fallback
                v = v / (np.linalg.norm(v) + 1e-12)
                if np.dot(v, v_ca) > 0.9:
                    # too aligned with the N->CA bond: flip away from CA
                    v = -v
        else:
            # H points opposite the bisector, away from both heavy atoms
            v = -sum_vec / s_norm

    H = np.asarray(N) + (v / (np.linalg.norm(v) + 1e-12)) * bond_length
    return H

# Internal-coordinate placement for exact bond length + angle + dihedral
def place_atom_internal(A, B, C, bond_length, bond_angle_deg, dihedral_deg):
    """
    Build atom D bonded to C from internal coordinates relative to A-B-C.

    The returned position satisfies:
      - |D - C| = bond_length
      - angle B-C-D = bond_angle_deg (degrees)
      - dihedral A-B-C-D = dihedral_deg (degrees); 0 puts D on the same side
        of the B-C axis as A, 180 on the opposite side.

    A local orthonormal frame is built at C:
      u = unit vector from C toward B
      v = unit component of (A - B) perpendicular to u
      w = u x v
    and D = C + r * (cos(theta)*u + sin(theta)*(cos(phi)*v + sin(phi)*w)).

    Raises ValueError when B and C coincide. If A is collinear with B-C,
    an arbitrary perpendicular is used for v.
    """
    A, B, C = (np.asarray(p) for p in (A, B, C))
    theta = math.radians(bond_angle_deg)
    phi = math.radians(dihedral_deg)

    # First frame axis: along the C -> B bond
    u, u_len = _safe_norm(B - C)
    if u_len < 1e-12:
        raise ValueError("B and C are coincident or too close")

    # Second frame axis: part of (A - B) that is orthogonal to u
    ab = A - B
    perp = ab - np.dot(ab, u) * u
    perp_len = np.linalg.norm(perp)
    if perp_len < 1e-12:
        # A lies on the B-C line; seed with whichever axis is not close to u
        seed = np.array([1.0, 0.0, 0.0]) if abs(u[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
        perp = seed - np.dot(seed, u) * u
        perp_len = np.linalg.norm(perp)
    v = perp / (perp_len + 1e-12)

    # Third frame axis completes the right-handed set
    w = np.cross(u, v)

    # Components of the C -> D direction in the (u, v, w) frame
    along_u = math.cos(theta)
    along_v = math.sin(theta) * math.cos(phi)
    along_w = math.sin(theta) * math.sin(phi)

    return C + bond_length * (along_u * u + along_v * v + along_w * w)


def place_amide_h_internal(N, CA, C_prev, bond_length=NH_BOND_LEN, bond_angle_deg=NH_ANGLE_DEG):
    """
    Place the amide hydrogen on N from internal coordinates.

    The reference chain passed to place_atom_internal is C_prev - CA - N,
    with H attached to N, so that:
      - |H - N| = bond_length
      - angle CA-N-H = bond_angle_deg
      - dihedral C_prev-CA-N-H = 180°, i.e. H lies in the C_prev/N/CA plane
        on the opposite side of the N-CA bond from C_prev.

    If CA or C_prev is None the bisector-based placement is used instead.
    """
    if CA is None or C_prev is None:
        # not enough reference atoms for internal coordinates
        return place_amide_h_bisector(N, CA, C_prev, bond_length=bond_length)

    in_plane_trans = 180.0
    return place_atom_internal(C_prev, CA, N, bond_length, bond_angle_deg, in_plane_trans)

# ------------------------
# Kabsch-Sander (DSSP) energy function
# ------------------------
def dssp_hbond_energy(O_coord, H_coord, C_coord, N_coord):
    """
    Electrostatic backbone H-bond energy in the style of DSSP (Kabsch & Sander).

      E = 27.888 * (1/r_ON + 1/r_CH - 1/r_OH - 1/r_CN)   [kcal/mol]

    O and C belong to the acceptor carbonyl, N and H to the donor amide.
    The prefactor used here is 27.888 (in the Kabsch & Sander formulation
    this is 0.42 * 0.20 * 332). Distances are taken in Å.

    Returns the energy as a float, or 0.0 if any of the four distances is
    below 1e-6 (which would otherwise divide by ~zero).
    """
    radii = (
        dist(O_coord, N_coord),   # r_ON
        dist(C_coord, H_coord),   # r_CH
        dist(O_coord, H_coord),   # r_OH
        dist(C_coord, N_coord),   # r_CN
    )
    if any(r < 1e-6 for r in radii):
        return 0.0
    r_on, r_ch, r_oh, r_cn = radii
    energy = 27.888 * ((1.0 / r_on) + (1.0 / r_ch) - (1.0 / r_oh) - (1.0 / r_cn))
    return float(energy)

# ------------------------
# PDB parsing utilities
# ------------------------
ResidueKey = namedtuple("ResidueKey", ["chain", "resseq", "icode", "resname"])

def residue_key(res):
    """Build a ResidueKey for ``res``.

    ``res`` is expected to look like a Biopython Residue: ``get_parent().id``
    is used as the chain id, ``id[1]`` as the sequence number, ``id[2]`` as
    the insertion code (whitespace stripped, so a blank code becomes "") and
    ``resname`` as the residue name.
    """
    residue_id = res.id
    return ResidueKey(
        chain=res.get_parent().id,
        resseq=residue_id[1],
        icode=residue_id[2].strip(),
        resname=res.resname,
    )

def extract_backbone_atoms(res):
    """
    Collect backbone atom coordinates of a residue.

    Returns a dict with the keys "N", "CA", "C", "O" and "H" (in that order).
    Each value is a copy of the atom's ``coord`` when the atom is present in
    ``res`` and None otherwise (H is frequently absent).
    """
    backbone_names = ("N", "CA", "C", "O", "H")
    return {
        name: (res[name].coord.copy() if name in res else None)
        for name in backbone_names
    }

# ------------------------
# Identify β-strands (simple geometric heuristic)
# ------------------------
def detect_extended_runs(residues, min_len=3):
    """
    Find runs of consecutive residues flagged by a CA-trace angle test.

    Only residues that contain a "CA" atom take part. For every interior CA i
    the angle between the step vectors CA_{i-1}->CA_i and CA_i->CA_{i+1} is
    computed, and the residue is flagged when that angle exceeds 140 degrees.
    Runs of at least ``min_len`` consecutive flagged residues are reported.

    Returns a list of (start_idx, end_idx) tuples, inclusive, expressed as
    indices into ``residues``.

    NOTE(review): the tested angle is between successive step vectors, so a
    perfectly straight trace gives 0 degrees and is NOT flagged. If the
    intent was the CA_{i-1}-CA_i-CA_{i+1} virtual bond angle, the first
    vector should be reversed -- confirm with the author.
    """
    ca_positions = [
        (idx, r["CA"].coord.copy()) for idx, r in enumerate(residues) if "CA" in r
    ]
    idx_map = [idx for idx, _ in ca_positions]
    ca_coords = np.array([xyz for _, xyz in ca_positions])
    n = len(ca_coords)

    # flag interior positions whose step-to-step angle exceeds the threshold
    extended_mask = np.zeros(n, dtype=bool)
    for k in range(1, n - 1):
        step_in = ca_coords[k] - ca_coords[k - 1]
        step_out = ca_coords[k + 1] - ca_coords[k]
        cosang = np.dot(step_in, step_out) / (
            np.linalg.norm(step_in) * np.linalg.norm(step_out) + 1e-12
        )
        cosang = max(-1.0, min(1.0, cosang))
        if math.degrees(math.acos(cosang)) > 140.0:
            extended_mask[k] = True

    # collect maximal runs of flagged positions; the extra step at pos == n
    # closes a run that reaches the end of the trace
    runs = []
    run_start = None
    for pos in range(n + 1):
        if pos < n and extended_mask[pos]:
            if run_start is None:
                run_start = pos
            continue
        if run_start is not None:
            if pos - run_start >= min_len:
                # translate CA-list positions back to residue indices
                runs.append((idx_map[run_start], idx_map[pos - 1]))
            run_start = None
    return runs

# ------------------------
# H-bond detection between residues
# ------------------------
def build_backbone_coords_map(chain_residues):
    """
    Build one backbone record per residue of a chain, in input order.

    Each record is a dict with the keys 'res' (the residue object), 'N',
    'CA', 'C', 'O' and 'H'. Coordinates are None for missing atoms. 'H' is
    the explicit hydrogen when the residue has one; otherwise it is estimated
    with place_amide_h_internal from N, CA and the C atom of the preceding
    list element (if any). No estimate is made (H is None) for PRO residues,
    when N or CA is missing, or when the placement raises.
    """
    entries = []
    prev_res = None
    for res in chain_residues:
        atoms = extract_backbone_atoms(res)

        # carbonyl carbon of the previous list element, if it has one
        c_prev = None
        if prev_res is not None and "C" in prev_res:
            c_prev = prev_res["C"].coord.copy()

        hydrogen = atoms.get("H", None)
        if hydrogen is None:
            can_estimate = (
                atoms["N"] is not None
                and atoms["CA"] is not None
                and res.resname != "PRO"
            )
            if can_estimate:
                try:
                    hydrogen = place_amide_h_internal(atoms["N"], atoms["CA"], c_prev)
                except Exception:
                    # placement is best-effort; leave H undefined on failure
                    hydrogen = None

        entries.append({
            "res": res,
            "N": atoms["N"],
            "CA": atoms["CA"],
            "C": atoms["C"],
            "O": atoms["O"],
            "H": hydrogen
        })
        prev_res = res
    return entries


def candidate_hbond_pairs(chain_entries, max_on = 3.5):
    """
    List index pairs (i, j), i != j, whose O_i ... N_j distance is at most
    ``max_on``.

    Entries lacking the relevant atom (value None) are skipped. Every ordered
    pair is examined (quadratic in the number of entries); to include
    inter-chain pairs the caller passes a combined entry list.
    """
    pairs = []
    for i, acceptor in enumerate(chain_entries):
        oi = acceptor["O"]
        if oi is None:
            continue
        for j, donor in enumerate(chain_entries):
            if j == i:
                continue
            nj = donor["N"]
            if nj is not None and dist(oi, nj) <= max_on:
                pairs.append((i, j))
    return pairs

# For multi-chain files we will compare across all residues in the model.
def compute_hbond_energies(entries):
    """
    Compute DSSP-style H-bond energies for all ordered residue pairs.

    entries: list of backbone records (dicts with 'N', 'C', 'O', 'H', ...).
    Returns a dict mapping (i, j) -> energy in kcal/mol, where residue i
    contributes the carbonyl (C=O) and residue j the amide (N-H). A pair is
    skipped when i == j, when the acceptor lacks O or C, or when the donor
    lacks N or H (H may be an estimated position). No distance or sequence
    separation filter is applied here.
    """
    energies = {}
    for i, acceptor in enumerate(entries):
        oi, ci = acceptor["O"], acceptor["C"]
        if oi is None or ci is None:
            continue
        for j, donor in enumerate(entries):
            if j == i:
                continue
            nj, hj = donor["N"], donor["H"]
            if nj is None or hj is None:
                continue
            try:
                # argument order expected by the energy function: O_i, H_j, C_i, N_j
                energies[(i, j)] = dssp_hbond_energy(oi, hj, ci, nj)
            except Exception:
                # numeric problems for a single pair are ignored
                pass
    return energies

# ------------------------
# Bridges, ladders, sheets
# ------------------------
def detect_bridges(entries, energies, energy_cutoff=HBOND_ENERGY_CUTOFF):
    """
    Detect beta-bridges from the H-bond energy map (i, j) -> E, where (i, j)
    means C=O of residue i bonded to N-H of residue j. A bond counts when its
    energy is strictly below ``energy_cutoff``.

      - Antiparallel bridge (i, j): both E(i, j) and E(j, i) qualify. Because
        the map is scanned entry by entry, each such bridge is reported twice,
        once as (i, j) and once as (j, i).
      - Parallel bridge (i, j): both E(i-1, j) and E(j, i+1) qualify, with
        |i - j| >= 5 and neither residue already part of an antiparallel
        bridge. i and j each range over 0 .. len(entries) - 2.

    Returns a list of dicts. Antiparallel records carry the keys
    'i', 'j', 'type', 'E_ij', 'E_ji'; parallel records carry
    'i', 'j', 'type', 'E_i_jplus1' (holding E(i-1, j)) and 'E_ip1_j'
    (holding E(j, i+1)).
    """
    bridges = []
    antipar_residues = set()

    # reciprocal H-bonds -> antiparallel bridge
    for (i, j), e_fwd in energies.items():
        if e_fwd >= energy_cutoff:
            continue
        e_back = energies.get((j, i), None)
        if e_back is not None and e_back < energy_cutoff:
            bridges.append({"i":i, "j":j, "type":"antiparallel", "E_ij":e_fwd, "E_ji":e_back})
            antipar_residues.update((i, j))

    # H-bonds flanking residue i on both sides -> parallel bridge
    last = len(entries) - 1
    for i in range(last):
        for j in range(last):
            if abs(i - j) < 5:
                # too close in sequence: avoids helices and tight hinges
                continue
            e_first = energies.get((i - 1, j), 1e6)
            e_second = energies.get((j, i + 1), 1e6)
            if not (e_first < energy_cutoff and e_second < energy_cutoff):
                continue
            # residues already in an antiparallel bridge are left alone
            if antipar_residues.isdisjoint((i, j)):
                bridges.append({"i":i, "j":j, "type":"parallel_shifted", "E_i_jplus1":e_first, "E_ip1_j":e_second})

    return bridges

def canonical_bridge_key(b):
    """Return (low_index, high_index, type) for bridge record ``b``, so that
    mirrored records (i, j) and (j, i) map to the same key."""
    first, second = b["i"], b["j"]
    ordered = (first, second) if first <= second else (second, first)
    return ordered + (b["type"],)

def aggregate_strands_from_bridges(bridges, dssp_strands):
    """
    Turn a list of beta-bridges into strands (runs of bridged residues).

    Parameters:
      bridges: list of dicts with at least the keys 'i', 'j' (integer residue
        indices) and 'type' (bridge type label).
      dssp_strands: list of dicts with a "range" key (list of residue
        indices); used to guide merging and to report missed strands.

    Steps:
      1. Index bridges by residue. When two bridges of the same type are two
         residues apart on both sides (diagonal or anti-diagonal), the
         intermediate residue pair is registered as bridged as well.
      2. Every maximal run of at least two consecutive bridged residues
         becomes a strand (including the run that ends on the highest
         bridged residue).
      3. A strand starting fewer than 3 residues after the end of the
         previously kept strand is merged with it, provided the merged span
         overlaps some DSSP strand in more than one residue; the merged range
         is the largest such overlap. Merging is cumulative, so a chain of
         several close fragments collapses into a single strand.
      4. Orientation between two strands is the type of the first connection
         found linking them.

    Returns (strands, orientations, missed_dssp_strands):
      strands: list of dicts with keys "range" (list of residue indices),
        "connections" (set of (residue, partner, type)), "type(s)" (set of
        bridge types) and "length".
      orientations: dict (first residue of strand A, first residue of
        strand B) -> bridge type, stored for both orders.
      missed_dssp_strands: DSSP strands longer than 6 residues that share no
        residue with any returned strand.
      With no strands at all the result is ([], {}, []).
    """
    # --- 1. index the bridges -------------------------------------------
    bridge_d = {}                     # (i, j) -> type
    bridgeres_d = defaultdict(list)   # residue -> [(partner, type), ...]
    n_min, n_max = 1000000, 0
    for b in bridges:
        bridge_d[(b['i'], b['j'])] = b['type']
        bridgeres_d[b['i']].append((b['j'], b['type']))
        bridgeres_d[b['j']].append((b['i'], b['type']))
        n_min = min(n_min, b['i'], b['j'])
        n_max = max(n_max, b['i'], b['j'])

    def same_type(key, btype):
        # True when a bridge exists at `key` with the given type
        return key in bridge_d and bridge_d[key] == btype

    def link(a, b, btype):
        # register residues a and b as bridged partners
        bridgeres_d[a].append((b, btype))
        bridgeres_d[b].append((a, btype))

    # Only bridge_d keys are consulted below, so links added here never
    # trigger further completions.
    for (i, j), btype in list(bridge_d.items()):
        # anti-diagonal neighbours (antiparallel-style completion)
        if same_type((i + 2, j - 2), btype):
            link(i + 1, j - 1, btype)
        if same_type((i - 2, j + 2), btype):
            link(i - 1, j + 1, btype)
        # diagonal neighbours (parallel-style completion)
        if same_type((j + 2, i + 2), btype) or same_type((i + 2, j + 2), btype):
            link(i + 1, j + 1, btype)
        if same_type((j - 2, i - 2), btype) or same_type((i - 2, j - 2), btype):
            link(i - 1, j - 1, btype)

    # --- 2. runs of consecutive bridged residues --------------------------
    strands = []

    def close_run(run, run_connections, run_types):
        # keep a finished run only if it spans at least two residues
        if len(run) > 1:
            strands.append({
                "range": run,
                "connections": run_connections,
                "type(s)": run_types,
                "length": len(run)
            })

    strand, connections, types = [], set(), set()
    for i in range(n_min, n_max + 1):
        if i in bridgeres_d:
            partners = bridgeres_d[i]
            types |= {t for _, t in partners}
            strand.append(i)
            connections |= {(i, m, t) for m, t in partners}
        else:
            close_run(strand, connections, types)
            strand, connections, types = [], set(), set()
    # n_max is always a bridged residue, so the last run is still open here
    # and must be closed explicitly or the final strand would be lost.
    close_run(strand, connections, types)

    strands = sorted(strands, key=lambda x: x['range'][0])

    if not strands:
        return [], {}, []

    # --- 3. merge close fragments, guided by DSSP strands -----------------
    dssp_sets = [set(d["range"]) for d in dssp_strands]
    new_strands = [strands[0]]
    for cur in strands[1:]:
        # compare against the strand kept last (possibly already merged) so
        # that successive merges accumulate instead of overwriting each other
        prev = new_strands[-1]
        if cur['range'][0] - prev['range'][-1] < 3:
            span = set(range(prev['range'][0], cur['range'][-1] + 1))
            maxinters = []
            for dssp_set in dssp_sets:
                inters = sorted(span & dssp_set)
                if len(inters) > len(maxinters):
                    maxinters = inters
            if len(maxinters) > 1:
                kept = set(maxinters)
                new_strands[-1] = {
                    "range": maxinters,
                    "connections": {
                        x for x in prev["connections"] | cur["connections"]
                        if x[0] in kept
                    },
                    "type(s)": prev["type(s)"] | cur["type(s)"],
                    "length": len(maxinters)
                }
                continue
        new_strands.append(cur)
    strands = new_strands

    # --- 4. pairwise orientation from the first linking connection --------
    orientations = {}
    for s1 in strands:
        for _, partner, t in s1["connections"]:
            for s2 in strands:
                if s1 == s2:
                    continue
                if partner in s2["range"]:
                    orientations.setdefault((s1["range"][0], s2["range"][0]), t)
                    orientations.setdefault((s2["range"][0], s1["range"][0]), t)
                    break

    # --- long DSSP strands with no bridge-derived counterpart -------------
    missed_dssp_strands = []
    for dssp_strand in dssp_strands:
        if len(dssp_strand["range"]) > 6:
            dssp_set = set(dssp_strand["range"])
            if not any(dssp_set & set(s["range"]) for s in strands):
                missed_dssp_strands.append(dssp_strand)

    return strands, orientations, missed_dssp_strands


def build_sheet_graph_from_strands(strands):
    """
    Build the strand contact graph.

    Each strand is identified by the first residue of its 'range'. Two
    strands are joined by an (undirected) edge when a connection of one
    points at a residue contained in the other's range. Strands without any
    such contact do not appear in the result.

    Returns the adjacency mapping: defaultdict node -> set of neighbour nodes.
    """
    adjacency = defaultdict(set)
    for strand in strands:
        for connection in strand['connections']:
            partner_res = connection[1]
            for other in strands:
                if other == strand or partner_res not in other['range']:
                    continue
                node, other_node = strand['range'][0], other['range'][0]
                adjacency[node].add(other_node)
                adjacency[other_node].add(node)
    return adjacency


def build_sheet_graph_from_ladders(ladders):
    """
    Build the strand graph implied by a list of ladders.

    Each ladder must provide 'strand_a_range' and 'strand_b_range'; the first
    two items of each range form the node key (start, end). The two strands
    of a ladder are joined by an undirected edge.

    Returns (adjacency, ladder_map, nodes):
      adjacency: defaultdict node -> set of neighbour nodes
      ladder_map: defaultdict node -> list of ladders touching that node
      nodes: node keys in order of first appearance
    """
    adjacency = defaultdict(set)
    ladder_map = defaultdict(list)
    first_seen = {}  # insertion-ordered registry of node keys

    for lad in ladders:
        range_a = lad['strand_a_range']
        a = (range_a[0], range_a[1])
        range_b = lad['strand_b_range']
        b = (range_b[0], range_b[1])

        first_seen.setdefault(a, len(first_seen))
        first_seen.setdefault(b, len(first_seen))

        adjacency[a].add(b)
        adjacency[b].add(a)
        ladder_map[a].append(lad)
        ladder_map[b].append(lad)

    return adjacency, ladder_map, list(first_seen)

def connected_components(adjacency):
    """
    Split a graph into connected components by breadth-first search.

    adjacency: dict node -> set(neighbors). A node that only ever appears as
      a neighbour (never as a key) is treated as having no outgoing edges;
      the mapping is not modified, even if it is a defaultdict.
    Returns a list of components, each a list of nodes in BFS visiting
    order; components follow the key order of ``adjacency``.
    """
    seen = set()
    comps = []
    # snapshot the keys so the traversal never depends on dict mutation
    for start in list(adjacency):
        if start in seen:
            continue
        seen.add(start)
        queue = deque([start])
        comp = []
        while queue:
            v = queue.popleft()
            comp.append(v)
            # .get avoids KeyError (dict) or key creation (defaultdict) for
            # nodes that have no adjacency entry of their own
            for nb in adjacency.get(v, ()):
                if nb not in seen:
                    seen.add(nb)
                    queue.append(nb)
        comps.append(comp)
    return comps

# ------------------------
# Orientation & planarity filters
# ------------------------
def estimate_local_max_curvature(points, k_neighbors=50, min_neighbors=10, verbose=False):
    """
    Estimate max(|k1|, |k2|) at each point in a 3D point-cloud surface.

    For every point, its nearest neighbours are expressed in a local frame
    obtained by PCA (normal = direction of least variance), a quadratic
    height function z = a*x^2 + b*x*y + c*y^2 + d*x + e*y + f is fitted by
    least squares, and the principal curvatures are the eigenvalues of
    inv(I) @ II built from the first and second fundamental forms at the
    point.

    Parameters
    ----------
    points : (N,3) array-like
      3D coordinates of points sampled from a 2D surface.
    k_neighbors : int
      Number of nearest neighbors to use for local fitting (including the
      center). The number actually used is
      min(N, max(k_neighbors, min_neighbors)).
    min_neighbors : int
      Minimum neighbors required to attempt a fit; if fewer points are
      available every curvature is np.nan.
    verbose : bool
      Print progress every 500 successfully processed points.

    Returns
    -------
    max_curv : (N,) ndarray
      Maximum absolute principal curvature at each point (1/unit length).
      NaN where estimation failed or was not attempted. An empty input gives
      an empty array.
    """
    pts = np.asarray(points, dtype=float)
    N = len(pts)
    max_curv = np.full(N, np.nan, dtype=float)

    # Never ask the tree for more neighbours than there are points: cKDTree
    # pads missing neighbours with index N, which is not a valid row of pts.
    n_queried = min(N, max(k_neighbors, min_neighbors))
    if N == 0 or n_queried < max(min_neighbors, 1):
        return max_curv

    tree = cKDTree(pts)
    # k nearest neighbours, the point itself included
    _, idxs = tree.query(pts, k=n_queried)

    for i in range(N):
        neigh_idx = np.atleast_1d(idxs[i])
        if len(neigh_idx) < min_neighbors:
            continue

        # neighbour coordinates relative to the centre point
        neigh_pts = pts[neigh_idx] - pts[i]

        # PCA of the neighbourhood: rows of Vt are the principal axes
        C = np.cov(neigh_pts.T)
        try:
            _, _, Vt = svd(C)
        except (np.linalg.LinAlgError, ValueError):
            try:
                _, _, Vt = svd(neigh_pts - neigh_pts.mean(axis=0))
            except (np.linalg.LinAlgError, ValueError):
                continue

        # smallest-variance axis is the surface normal; build tangent axes
        normal = Vt[-1, :]
        normal /= np.linalg.norm(normal) + 1e-16
        u = Vt[0, :]
        u /= np.linalg.norm(u) + 1e-16
        v = np.cross(normal, u)
        v /= np.linalg.norm(v) + 1e-16

        # local coordinates of the neighbours
        x = neigh_pts.dot(u)
        y = neigh_pts.dot(v)
        z = neigh_pts.dot(normal)

        # least-squares fit of z = a*x^2 + b*x*y + c*y^2 + d*x + e*y + f0
        A = np.vstack([x**2, x*y, y**2, x, y, np.ones_like(x)]).T
        try:
            coeffs, *_ = lstsq(A, z, rcond=None)
        except (np.linalg.LinAlgError, ValueError):
            continue
        a, b, c, d, e, _f0 = coeffs

        # derivatives at the origin (the centre point)
        fx, fy = d, e
        fxx, fxy, fyy = 2.0 * a, b, 2.0 * c

        # first fundamental form
        first_form = np.array([[1.0 + fx*fx, fx * fy],
                               [fx * fy, 1.0 + fy*fy]])
        # second fundamental form, normalised by the surface-element factor
        W = np.sqrt(1.0 + fx*fx + fy*fy)
        second_form = np.array([[fxx / W, fxy / W],
                                [fxy / W, fyy / W]])

        # principal curvatures = eigenvalues of inv(I) @ II
        try:
            shape_op = inv(first_form).dot(second_form)
            vals, _ = eig(shape_op)
        except (np.linalg.LinAlgError, ValueError):
            continue
        if np.max(np.abs(np.imag(vals))) > 1e-6:
            # complex eigenvalues signal an unstable fit; leave NaN
            continue
        max_curv[i] = np.max(np.abs(np.real(vals)))

        if verbose and (i % 500 == 0):
            print(f"Processed {i}/{N}")

    return max_curv


def strand_axis_vector(entries, start_idx, end_idx):
    """
    Unit vector from the first to the last available CA of a strand.

    entries is the global list of backbone records; start_idx and end_idx
    (inclusive) are positions in that list. Records whose "CA" is None are
    ignored. With fewer than two usable CA atoms a zero vector is returned.
    """
    ca_trace = []
    for record in (entries[k] for k in range(start_idx, end_idx + 1)):
        if record["CA"] is not None:
            ca_trace.append(record["CA"])
    if len(ca_trace) < 2:
        return np.array([0.0,0.0,0.0])
    head, tail = ca_trace[0], ca_trace[-1]
    return norm(np.asarray(tail) - np.asarray(head))


def mean_ca_distance_between_pairs(entries, pairs):
    """
    Mean CA-CA distance over the residue index pairs in ``pairs``.

    Pairs in which either record has no CA (None) are ignored. If no pair is
    usable the result is float('inf').
    """
    separations = []
    for i, j in pairs:
        ca_i, ca_j = entries[i]["CA"], entries[j]["CA"]
        if ca_i is not None and ca_j is not None:
            separations.append(dist(ca_i, ca_j))
    return float(np.mean(separations)) if separations else float('inf')

# ------------------------
# Main processing for a file
# ------------------------
def backbone_model(chain_res_lists):
    """
    Build backbone records for every chain.

    chain_res_lists: iterable of (chain_id, residues) pairs.
    Returns a dict chain_id -> list of backbone records as produced by
    build_backbone_coords_map (a repeated chain_id keeps the last list).
    """
    return {
        chain_id: build_backbone_coords_map(residues)
        for chain_id, residues in chain_res_lists
    }


def find_beta_sheets(entries, DSSP_strands):
    """
    Detect beta-sheets for one list of backbone records.

    Pipeline: pairwise H-bond energies -> bridges -> strands (guided by
    DSSP_strands) -> strand contact graph -> connected components. Each
    component becomes one sheet, given as the list of residue ranges of its
    strands.

    Returns (bridges, strands, orientations, beta_sheets, missed_dssp_strands).
    """
    # H-bond energies over all residue pairs, then bridges from them
    energies = compute_hbond_energies(entries)
    bridges = detect_bridges(entries, energies)

    # group bridged residues into strands
    strands, orientations, missed_dssp_strands = aggregate_strands_from_bridges(bridges, DSSP_strands)

    # sheets are the connected components of the strand contact graph
    # TODO upgrade to spectral clustering for a better separation of beta sheets
    comps = connected_components(build_sheet_graph_from_strands(strands))

    # graph nodes are first residues of strands; map them back to ranges
    beta_sheets = [
        [
            strand['range']
            for r0 in comp
            for strand in strands
            if strand['range'][0] == r0
        ]
        for comp in comps
    ]

    return bridges, strands, orientations, beta_sheets, missed_dssp_strands

# Minimum number of consecutive flagged residues for a run to count as a helix
MIN_HELIX_LENGTH = 6


def detect_helices_by_geometry(entries):
    """
    Heuristic helix detection from CA-CA distances.

    For every i, if the CA(i)-CA(i+4) distance is below 6.0 Å, residues
    i .. i+3 are flagged. Runs of at least MIN_HELIX_LENGTH consecutive
    flagged residues are reported as helices. Pairs in which either record
    has no CA (None) are skipped instead of raising.

    entries: list of backbone records (dicts with a "CA" entry).
    Returns a list of dicts with keys "range" (list of indices into
    entries), "turn" (always 4) and "length".
    """
    helices = []
    n = len(entries)
    if n < MIN_HELIX_LENGTH:
        return helices

    # flag residues from close i -> i+4 CA pairs
    helix_mask = np.zeros(n, dtype=bool)
    for i in range(n - 4):
        ca_start = entries[i]["CA"]
        ca_end = entries[i + 4]["CA"]
        if ca_start is None or ca_end is None:
            # missing atom: no distance can be measured for this pair
            continue
        if np.linalg.norm(ca_start - ca_end) < 6.0:  # loose threshold
            helix_mask[i:i + 4] = True

    # group flagged residues into runs
    i = 0
    while i < n:
        if not helix_mask[i]:
            i += 1
            continue
        j = i
        while j < n and helix_mask[j]:
            j += 1
        # residues i..j-1 form one run
        if (j - i) >= MIN_HELIX_LENGTH:
            helices.append({
                "range": list(range(i, j)),
                "turn": 4,   # 3 and 5 not implemented in this function !
                "length": j - i
            })
        i = j
    return helices


def find_runs_regex(s):
    """
    Locate helix runs in a DSSP secondary-structure string.

    s: string with one secondary-structure code per position.
    Returns a dict with keys "G", "H" and "I"; each value is a list of runs,
    and each run is the list of string positions it covers. Minimum run
    lengths are 3 for "G", 4 for "H" and 5 for "I".
    """
    patterns = {"G": r"G{3,}", "H": r"H{4,}", "I": r"I{5,}"}
    # Match.end() is already exclusive, so range(start, end) covers exactly
    # the matched characters (end + 1 would add one position past the run).
    return {
        code: [list(range(m.start(), m.end())) for m in re.finditer(pattern, s)]
        for code, pattern in patterns.items()
    }


def detect_helices_strands_per_chain_with_DSSP(model, fn, chain_res_lists, dssp_path=None):
    """
    Run DSSP on `model` and extract helix and strand runs for every chain.

    model: Biopython model passed to Bio.PDB.DSSP.
    fn: path of the structure file DSSP is run on.
    chain_res_lists: iterable of (chain_id, residues) pairs.
    dssp_path: DSSP executable. When None it is resolved at call time from
      the MKDSSP environment variable, falling back to "mkdssp"; resolving it
      lazily keeps the module importable when MKDSSP is not set.

    Returns (all_helices, all_strands), both dicts keyed by chain_id:
      all_helices[chain_id]: list of {"range", "turn", "length"} sorted by
        start, with turn 3/4/5 for DSSP codes G/H/I;
      all_strands[chain_id]: list of {"range", "length"} for runs of E.
    "range" lists positions in that chain's `residues` list.
    """
    if dssp_path is None:
        dssp_path = os.environ.get("MKDSSP", "mkdssp")
    dssp = DSSP(model, fn, dssp=dssp_path)
    dssp_dict = dict(dssp)

    # (turn, pattern) per helix type, with the minimum run length in the regex
    helix_patterns = ((3, r"G{3,}"), (4, r"H{4,}"), (5, r"I{5,}"))

    all_helices = {}
    all_strands = {}
    for chain_id, residues in chain_res_lists:
        chain_ss = []
        for residue in residues:
            key = (residue.get_full_id()[2], residue.id)
            record = dssp_dict.get(key)
            if record is None:
                print(key, "NOT FOUND")
                # Keep one character per residue so string positions stay
                # aligned with positions in `residues`; "-" also breaks any
                # run that would otherwise span the missing residue.
                chain_ss.append("-")
            else:
                chain_ss.append(record[2])

        ss_string = "".join(chain_ss)

        # Match.end() is exclusive, so range(start, end) is exactly the run.
        helix_runs = sorted(
            (
                (m.start(), turn, list(range(m.start(), m.end())))
                for turn, pattern in helix_patterns
                for m in re.finditer(pattern, ss_string)
            ),
            key=lambda run: run[0],
        )
        all_helices[chain_id] = [
            {"range": hrange, "turn": turn, "length": len(hrange)}
            for _, turn, hrange in helix_runs
        ]

        strand_runs = [
            list(range(m.start(), m.end()))
            for m in re.finditer(r"E{3,}", ss_string)
        ]
        all_strands[chain_id] = [
            {"range": srange, "length": len(srange)} for srange in strand_runs
        ]

    return all_helices, all_strands
 

def find_helices(entries):
    """Return the helices found in `entries` by the Cα-geometry heuristic."""
    helices = detect_helices_by_geometry(entries)
    return helices


def runs_above_threshold(arr, threshold, N):
    """
    Find runs of consecutive elements with value <= threshold.

    Despite the name, an element belongs to a run when `arr <= threshold`.
    arr: 1-D numpy array; NaN entries never satisfy the comparison.
    Returns a list of (start, end) index pairs, end exclusive, for every run
    of at least N elements.
    """
    at_or_below = arr <= threshold
    # zero sentinels on both sides so that runs touching either end of the
    # array still produce a rising and a falling edge
    padded = np.concatenate(([0], at_or_below.astype(int), [0]))
    edges = np.diff(padded)
    run_starts = np.where(edges == 1)[0]
    run_ends = np.where(edges == -1)[0]

    long_enough = []
    for first, past_last in zip(run_starts, run_ends):
        if past_last - first >= N:
            long_enough.append((first, past_last))
    return long_enough


THR_HELIX_FAIL_RUN = 8  # run length of radii <= fail threshold that triggers "bend_local_fail"

def qc_helices(ca_coords: Dict[int,np.ndarray],
              helices: List[List[int]],
              turns: List[int],
              thresholds: Dict = None):
    """
    Flag helices whose local curvature radius is too small.

    ca_coords: dict residue_index -> np.array([x,y,z])
    helices: list of helices (list of residue indices)
    turns: turn size of each helix, aligned with `helices`; it is used both as
      window size and as subsampling step of the curvature computation
    thresholds: optional dict overriding "curv_rad_local_fail" /
      "curv_rad_local_warn"
    Returns:
      {"helices": [...]} with one {"idx", "res", "stats", "flags"} record per
      helix of at least 4*turn residues; shorter helices are skipped.
    """
    th = {
        "curv_rad_local_fail": 30.0,
        "curv_rad_local_warn": 40.0
    }
    if thresholds is not None:
        th.update(thresholds)

    reports = []
    for h_idx, helix in enumerate(helices):
        turn = turns[h_idx]
        if len(helix) < 4*turn:
            continue
        _, curvature_radii = local_curvature_radii_for_strand(helix, ca_coords, turn, turn)

        # fail needs a sustained run of small radii; warn fires on any
        # single radius at or below the warn threshold
        if runs_above_threshold(curvature_radii, th["curv_rad_local_fail"], THR_HELIX_FAIL_RUN):
            flags = ["bend_local_fail"]
        elif np.any((curvature_radii <= th["curv_rad_local_warn"])):
            flags = ["bend_local_warn"]
        else:
            flags = []

        stats = {
            "h_idx": h_idx,
            "length": len(helix),
            "curvature_radii": curvature_radii
        }
        reports.append({"idx": h_idx, "res": helix, "stats": stats, "flags": flags})

    return {"helices": reports}


def helix_axis_bend_metrics(ca_coords, turn_len=4,
                            min_sustained_runs=3,
                            radius_threshold=8.0):
    """
    Bending metrics of a helix axis traced by per-turn geocenters.

    Returns:
      radii: array of local osculating radii (Å), NaN replaced by inf
      minR, medianR: minimum and median radius (inf when there are none)
      max_sustained_run: longest count of consecutive radii < radius_threshold
    Notes:
      - Use turn_len=4 for alpha helices (approx one turn).
      - radius_threshold is recommended ~8 Å to flag strong bending; tune as needed.
      - min_sustained_runs is accepted but not read by this function.
    """
    centers = compute_turn_geocenters(ca_coords, turn_len=turn_len)
    if len(centers) < 3:
        return np.array([]), float('inf'), float('inf'), 0

    raw_radii = local_osculating_radii_from_geocenters(centers, triplet_step=1)
    # numeric sanity: an undefined radius is treated as "straight"
    radii = np.where(np.isnan(raw_radii), np.inf, raw_radii)

    if radii.size:
        minR = float(np.nanmin(radii))
        medianR = float(np.nanmedian(radii))
    else:
        minR = medianR = float('inf')

    tight = radii < radius_threshold
    max_run = 0
    if tight.size != 0:
        # cut after every False so each chunk holds one run of True values
        cut_points = np.where(~tight)[0] + 1
        chunks = np.split(tight, cut_points)
        max_run = max((int(chunk.sum()) for chunk in chunks), default=0)
    return radii, minR, medianR, max_run



def dihedral_deg(p0, p1, p2, p3):
    """
    Compute the signed dihedral angle (degrees) defined by 4 points.

    The angle is measured between the plane (p0, p1, p2) and the plane
    (p1, p2, p3) around the central bond p1 -> p2, in the range [-180, 180].
    Inputs may be any array-likes of 3 coordinates (integers included).
    Returns NaN when the angle is undefined: a NaN coordinate, coincident
    p1/p2, or three consecutive collinear points.
    """
    # Convert to float so integer input neither truncates nor fails.
    p0, p1, p2, p3 = (np.asarray(p, dtype=float) for p in (p0, p1, p2, p3))
    b0 = p1 - p0
    b1 = p2 - p1
    b2 = p3 - p2

    # normal vectors to the two planes
    n1 = np.cross(b0, b1)
    n2 = np.cross(b1, b2)

    n1_len = np.linalg.norm(n1)
    n2_len = np.linalg.norm(n2)
    b1_len = np.linalg.norm(b1)
    if n1_len < 1e-12 or n2_len < 1e-12 or b1_len < 1e-12:
        return np.nan

    n1 = n1 / n1_len
    n2 = n2 / n2_len
    # unit vector along the CENTRAL bond: n1 x n2 is parallel to b1, so
    # projecting on b1 gives sin(angle). Projecting on b2 instead would scale
    # the sine by cos(angle between b1 and b2) and distort the result.
    b1u = b1 / b1_len

    x = np.dot(n1, n2)
    y = np.dot(np.cross(n1, n2), b1u)
    return np.degrees(np.arctan2(y, x))

def angle_between(a, b):
    """
    Return the angle between vectors a and b in degrees, in [0, 180].

    Returns 0.0 when either vector has (near-)zero length.
    """
    vec_a = np.asarray(a)
    vec_b = np.asarray(b)
    len_a = np.linalg.norm(vec_a)
    len_b = np.linalg.norm(vec_b)
    if len_a < 1e-12 or len_b < 1e-12:
        return 0.0
    cosine = np.dot(vec_a, vec_b) / (len_a * len_b)
    # clamp rounding overshoot so acos stays inside its domain
    clamped = max(-1.0, min(1.0, cosine))
    return math.degrees(math.acos(clamped))

# -----------------------------
# Strand-level local metrics
# -----------------------------
def catmull_rom_derivatives(P0, P1, P2, P3, t=0.0):
    """
    First and second derivative of the Catmull-Rom segment running P1→P2.

    P0..P3: the four consecutive control points; t: position on the segment.
    Returns (C'(t), C''(t)).
    """
    prev_pt, seg_start, seg_end, next_pt = (np.asarray(p) for p in (P0, P1, P2, P3))

    # polynomial coefficients of C(t) = 0.5 * (2*P1 + lin*t + quad*t^2 + cubic*t^3)
    lin = -prev_pt + seg_end
    quad = 2*prev_pt - 5*seg_start + 4*seg_end - next_pt
    cubic = -prev_pt + 3*seg_start - 3*seg_end + next_pt

    first = 0.5 * (lin + 2*quad*t + 3*cubic*t*t)
    second = 0.5 * (2*quad + 6*cubic*t)
    return first, second


def fit_bspline(ca, degree=3):
    """
    Build a clamped uniform B-spline that uses the points ca (Nx3) as its
    control points, so the curve approximates rather than interpolates them.
    Returns BSpline objects for C(t), C'(t), C''(t), with t in [0, 1].
    """
    points = np.asarray(ca)
    n_ctrl = len(points)
    n_knots = n_ctrl + degree + 1

    # open uniform knot vector: `degree` zeros, a uniform ramp 0..1,
    # then `degree` ones
    ramp = np.linspace(0, 1, n_knots - 2*degree)
    knots = np.concatenate((np.zeros(degree), ramp, np.ones(degree)))

    curve = BSpline(knots, points.copy(), degree)
    return curve, curve.derivative(1), curve.derivative(2)


def curvature_radius(Cp, Cpp):
    """
    Return the curvature radius R = 1/kappa from the first (Cp) and second
    (Cpp) derivative of a curve, with kappa = |Cp x Cpp| / |Cp|^3.
    Returns inf when the speed or the curvature is (near) zero.
    """
    binormal = np.cross(Cp, Cpp)
    speed_cubed = np.linalg.norm(Cp)**3
    if speed_cubed < 1e-12:
        return np.inf  # zero curvature (straight)
    kappa = np.linalg.norm(binormal) / speed_cubed
    if kappa < 1e-12:
        return np.inf
    return 1.0 / kappa


def backbone_curvature_catmull(catoms):
    """
    catoms: Nx3 array of Cα coordinates
    returns: length-N array of curvature radii, evaluated at the start (t=0)
    of each Catmull-Rom segment; positions 0, N-2 and N-1 lack the four
    neighbouring points and stay NaN.
    """
    points = np.asarray(catoms)
    n_points = len(points)
    radii = np.full(n_points, np.nan)

    for center in range(1, n_points - 2):
        before, here, after, after2 = points[center - 1:center + 3]
        first, second = catmull_rom_derivatives(before, here, after, after2, t=0.0)
        radii[center] = curvature_radius(first, second)

    return radii


def sample_spline(C, n=2000):
    """Evaluate C at n evenly spaced parameters in [0, 1]; return (t-values, samples)."""
    params = np.linspace(0, 1, n)
    return params, C(params)


def curvature_at_ca(ca, C, Cp, Cpp):
    """
    Curvature radius of the spline at the sample closest to each point of ca.

    ca: iterable of 3D points; C, Cp, Cpp: the spline and its first and
    second derivative. Each point is matched to the nearest of the spline
    samples produced by sample_spline, and the radius is evaluated at that
    sample's parameter.
    Returns an array with one radius per point.
    """
    params, samples = sample_spline(C)

    def radius_near(point):
        # nearest spline sample stands in for a true projection
        nearest = np.argmin(np.linalg.norm(samples - point, axis=1))
        t_near = params[nearest]
        return curvature_radius(Cp(t_near), Cpp(t_near))

    return np.array([radius_near(point) for point in ca])


def find_geocenters(coords, window_size=3):
    """
    Return the mean point of every sliding window of `window_size`
    consecutive rows of coords, stacked as a
    (len(coords) - window_size + 1, dim) array.
    """
    points = np.asarray(coords)
    n_windows = len(points) - window_size + 1
    window_means = [points[first:first + window_size].mean(axis=0)
                    for first in range(n_windows)]
    return np.vstack(window_means)


def plot_geocenters(all_coords, title=""):
    """
    Show an interactive 3D line plot with one trace per coordinate list.

    all_coords: iterable of point lists; the first three columns of each are
    drawn as x, y, z.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    figure = plt.figure(figsize=(6,6))
    axes = figure.add_subplot(111, projection='3d')

    for trace in all_coords:
        xyz = np.array(trace)
        axes.plot(xyz[:,0], xyz[:,1], xyz[:,2])

    for set_label, text in ((axes.set_xlabel, 'X'),
                            (axes.set_ylabel, 'Y'),
                            (axes.set_zlabel, 'Z')):
        set_label(text)
    axes.set_title(title)

    plt.show()


def local_curvature_radii_for_strand(strand_resids: List[int], ca_coords: Dict[int, np.ndarray], window_size: int, center_subsampling: int):
    """
    Compute local curvature radii along a strand (or helix) from its Cα trace.

    The Cα trace is padded with the residue before and after the segment,
    averaged over sliding windows of `window_size` residues (geocenters), and
    a B-spline is built on every `center_subsampling`-th geocenter. The
    curvature radius is then read off that spline near each geocenter.

    strand_resids: residue indices of the segment; must not be empty
      (strand_resids[0] is read unconditionally).
    ca_coords: dict residue_index -> Cα coordinate; missing residues, the
      padding ones included, are replaced by NaN coordinates.
    window_size: number of consecutive residues averaged per geocenter.
    center_subsampling: step used when picking geocenters as control points.

    Returns:
      indices, curvature_radii. `indices` is strand_resids itself. The radii
      array is all-NaN when the segment is shorter than window_size or fewer
      than 4 subsampled geocenters are available.
      NOTE(review): one radius is produced per geocenter, i.e.
      len(coords) - window_size + 1 values, so for window_size > 3 the
      returned array looks shorter than strand_resids - confirm callers do
      not rely on equal lengths.
    """
    n = len(strand_resids)
    # all-NaN result used for the "segment too short" early return
    angles = np.full(n, np.nan)

    coords = []
    # padding: residue preceding the segment (NaN if unknown)
    coords.append(np.asarray(ca_coords.get(strand_resids[0]-1, np.array([np.nan,np.nan,np.nan]))))
    for r in strand_resids:
        coords.append(np.asarray(ca_coords.get(r, np.array([np.nan,np.nan,np.nan]))))
    # padding: residue following the segment (NaN if unknown)
    coords.append(np.asarray(ca_coords.get(strand_resids[-1]+1, np.array([np.nan,np.nan,np.nan]))))
    # 3-residue segments get a second trailing padding residue
    if len(strand_resids) == 3:
        coords.append(np.asarray(ca_coords.get(strand_resids[-1]+2, np.array([np.nan,np.nan,np.nan]))))
    coords = np.vstack(coords)

    # geocenters of sliding windows over the padded trace
    if n < window_size:
        return strand_resids, angles
    centers = find_geocenters(coords, window_size=window_size)


    # a cubic B-spline needs at least 4 control points
    if centers[::center_subsampling].shape[0] > 3:
        C, Cp, Cpp = fit_bspline(centers[::center_subsampling])
        # radii are evaluated for ALL geocenters, not only the subsampled ones
        R = curvature_at_ca(centers, C, Cp, Cpp)
    else:
        R = np.full(n, np.nan)

    # drop radii beyond the segment length (geocenters of trailing padding)
    return strand_resids, R[:len(strand_resids)]


def local_torsion_for_strand(strand_resids: List[int], ca_coords: Dict[int, np.ndarray], smooth_window=1):
    """
    CA-based torsion along a strand: for each run of four consecutive
    residues the dihedral(CA_i, CA_{i+1}, CA_{i+2}, CA_{i+3}) is stored at the
    position of the second residue.
    Returns (strand_resids, torsions); torsions is aligned to strand_resids
    and holds NaN at position 0 and at the last two positions. Residues
    missing from ca_coords enter the dihedral as NaN coordinates.
    smooth_window is accepted but not read by this function.
    """
    unknown = np.array([np.nan,np.nan,np.nan])
    trace = [np.asarray(ca_coords.get(resid, unknown)) for resid in strand_resids]
    torsions = np.full(len(strand_resids), np.nan)

    quads = zip(trace, trace[1:], trace[2:], trace[3:])
    for slot, (q0, q1, q2, q3) in enumerate(quads, start=1):
        torsions[slot] = dihedral_deg(q0, q1, q2, q3)
    return strand_resids, torsions

# -----------------------------
# Osculating radius via geocenters (strand curvature)
# -----------------------------
def turn_geocenters_for_strand(strand_resids: List[int], ca_coords: Dict[int, np.ndarray], turn_len=3):
    """
    Geocenters of every window of `turn_len` consecutive Cα atoms of a strand.

    Returns (centers_idx, centers): centers_idx[k] is the position within
    strand_resids of the middle residue of window k, and centers is the
    stacked array of window means. Returns ([], []) when the strand is
    shorter than turn_len. Every residue must be present in ca_coords.
    """
    trace = [np.asarray(ca_coords[resid]) for resid in strand_resids]
    n_windows = len(trace) - turn_len + 1
    if n_windows < 1:
        return [], []

    half = turn_len//2
    centers_idx = [first + half for first in range(n_windows)]
    centers = [np.mean(trace[first:first + turn_len], axis=0)
               for first in range(n_windows)]
    return centers_idx, np.vstack(centers)

def circumradius_of_triangle(A,B,C):
    """
    Radius of the circle through the points A, B and C (R = abc / 4·area).
    Returns inf for (near-)collinear points.
    """
    pa, pb, pc = np.asarray(A), np.asarray(B), np.asarray(C)
    side_a = np.linalg.norm(pb - pc)
    side_b = np.linalg.norm(pc - pa)
    side_c = np.linalg.norm(pa - pb)
    area = 0.5 * np.linalg.norm(np.cross(pb - pa, pc - pa))
    if area < 1e-8:
        return np.inf
    return (side_a * side_b * side_c) / (4.0 * area)

def local_osculating_radii_from_centers(centers: np.ndarray):
    """
    Circumradius of every triplet of consecutive centers.
    Returns an array of len(centers) - 2 radii, or an empty array when fewer
    than 3 centers are given.
    """
    if len(centers) < 3:
        return np.array([])
    triplets = zip(centers, centers[1:], centers[2:])
    return np.array([circumradius_of_triangle(first, mid, last)
                     for first, mid, last in triplets])

# -----------------------------
# Simple CA-based per-residue rise/spacing
# -----------------------------
def ca_spacing_for_strand(strand_resids: List[int], ca_coords: Dict[int, np.ndarray]):
    """
    Distances between consecutive Cα atoms of a strand.

    strand_resids: residue indices of the strand; each must be a key of
      ca_coords (KeyError otherwise).
    Returns a float array of len(strand_resids) - 1 distances. Strands with
    zero or one residue give an empty array; an empty strand used to raise
    ValueError because an array of size -1 was requested.
    """
    coords = [np.asarray(ca_coords[r]) for r in strand_resids]
    spacings = [np.linalg.norm(nxt - cur) for cur, nxt in zip(coords, coords[1:])]
    # explicit dtype keeps the empty result a float array
    return np.array(spacings, dtype=float)

# -----------------------------
# Inter-strand metrics (pairing)
# -----------------------------
def pairing_indices_for_pair(strandA: List[int], strandB: List[int], orientation='antiparallel'):
    """
    Pair the residues of two strands by position, without any alignment search.
    For antiparallel: A[i] is paired with B[lenB-1-i] (B read backwards).
    Any other orientation value: A[i] is paired with B[i].
    Only the first min(lenA, lenB) pairs are produced; an overlap shorter than
    2 gives an empty list.
    """
    overlap = min(len(strandA), len(strandB))
    if overlap < 2:
        return []

    partners = strandB[::-1] if orientation == 'antiparallel' else strandB
    return [(strandA[pos], partners[pos]) for pos in range(overlap)]

def pairwise_ca_distances(pairs: List[Tuple[int,int]], ca_coords: Dict[int, np.ndarray]):
    """
    Cα–Cα distance of every residue pair, as an array aligned with `pairs`.
    A pair with a residue missing from ca_coords yields NaN.
    """
    def separation(res_a, res_b):
        point_a = ca_coords.get(res_a)
        point_b = ca_coords.get(res_b)
        if point_a is None or point_b is None:
            return np.nan
        return np.linalg.norm(point_a - point_b)

    return np.array([separation(res_a, res_b) for (res_a, res_b) in pairs])

def pairwise_twist_dihedrals(pairs: List[Tuple[int,int]], ca_coords: Dict[int, np.ndarray], orientation='antiparallel'):
    """
    Local twist between two paired strands: for consecutive pairs (a, b) and
    (a', b') the value is dihedral(CA[a], CA[a'], CA[b], CA[b']).
    The pairing order already encodes the strand direction (for antiparallel
    pairings b' precedes b in the original strand); `orientation` itself is
    accepted but not read here.
    Returns an array of len(pairs) - 1 dihedrals (deg); any failure while
    computing a value (e.g. a residue missing from ca_coords) yields NaN.
    """
    twists = []
    for (a_here, b_here), (a_after, b_after) in zip(pairs, pairs[1:]):
        try:
            twist = dihedral_deg(ca_coords[a_here],
                                 ca_coords[a_after],
                                 ca_coords[b_here],
                                 ca_coords[b_after])
        except Exception:
            twist = np.nan
        twists.append(twist)
    return np.array(twists)

# -----------------------------
# Top-level per-sheet QC aggregator
# -----------------------------
def qc_sheets(ca_coords: Dict[int,np.ndarray],
              strands: List[List[int]],
              sheets: List[List[int]],
              orientation_map: Dict[Tuple[int,int], str] = None,
              thresholds: Dict = None):
    """
    Compute per-strand geometry metrics and quality flags.

    ca_coords: dict residue_index -> np.array([x,y,z])
    strands: list of strands (list of residue indices)
    sheets: list of sheets (list of strand indices); accepted but not read by
      the current body
    orientation_map: optional dict mapping (strand_idx_a, strand_idx_b) -> 'parallel'/'antiparallel';
      accepted but not read by the current body
    thresholds: optional dict to override defaults
    Returns:
      result dict {"strands": [...], "sheets": []}. Each strand record is
      {"idx", "res", "stats", "flags"}; the "sheets" list is never filled
      here. Possible flags: "bend_local_fail"/"bend_local_warn",
      "torsion_local_fail"/"torsion_local_warn", "ca_spacing_out_of_range".
    """
    # default thresholds (the pair_ca_dist_* and mean_pair_twist_* entries are
    # defined but not used below)
    default_th = {
        "curv_rad_local_fail": 5.0,
        "curv_rad_local_warn": 7.0,
        "torsion_deg_warn": 25.0,
        "torsion_deg_fail": 40.0,
        "ca_spacing_min": 3.0,
        "ca_spacing_max": 4.0,
        "pair_ca_dist_warn": 5.5,
        "pair_ca_dist_fail": 6.5,
        "mean_pair_twist_warn": 25.0,
        "mean_pair_twist_fail": 40.0
    }
    if thresholds is not None:
        default_th.update(thresholds)
    th = default_th

    results = {"strands": [], "sheets": []}

    # per-strand analysis
    for s_idx, strand in enumerate(strands):
        # window of 3 residues per geocenter, every 2nd geocenter as spline control point
        strand_res, curvature_radii = local_curvature_radii_for_strand(strand, ca_coords, 3, 2)
        _, torsions = local_torsion_for_strand(strand, ca_coords, smooth_window=1)
        dists = ca_spacing_for_strand(strand, ca_coords)
        # osculating radius via turn-centers
        centers_idx, centers = turn_geocenters_for_strand(strand, ca_coords, turn_len=3)
        radii = local_osculating_radii_from_centers(centers) if len(centers)>0 else np.array([])
        # basic stats
        stats = {
            "strand_idx": s_idx,
            "length": len(strand),
            "curvature_radii": curvature_radii,   # per-res aligned to strand_res
            "torsions": torsions,
            "ca_spacings": dists,
            "ca_spacing_mean": float(np.nanmean(dists)) if len(dists)>0 else np.nan,
            "osculating_radii": radii,
            "osculating_r_min": float(np.nanmin(radii)) if radii.size>0 else np.nan,
            "osculating_r_median": float(np.nanmedian(radii)) if radii.size>0 else np.nan,
        }
        # flags
        flags = []
        # If the curvature goes along with a 0 deg torsion (that preserves H-bonds), then it is a beta-helix and it's ok
        # NOTE(review): as written, the bend flags are suppressed when a tight-curvature
        # residue has a LARGE |torsion| (>= threshold), not a near-zero one - confirm intent.
        if np.sum((curvature_radii <= th["curv_rad_local_fail"])) >= 2 and not np.any((curvature_radii <= th["curv_rad_local_fail"]) & (np.abs(torsions) >= th["torsion_deg_fail"])):
            flags.append("bend_local_fail")
        # NOTE(review): the suppression term below uses curv_rad_local_fail although
        # this is the warn branch - possibly meant curv_rad_local_warn.
        elif np.any((curvature_radii <= th["curv_rad_local_warn"])) and not np.any((curvature_radii <= th["curv_rad_local_fail"]) & (np.abs(torsions) >= th["torsion_deg_warn"])):
            flags.append("bend_local_warn")
        # torsion counts only when it is far from both 0 and 180 degrees
        if np.sum((np.abs(torsions) >= th["torsion_deg_fail"]) & (180 - np.abs(torsions) >= th["torsion_deg_fail"])) >= 2:
            flags.append("torsion_local_fail")
        # NOTE(review): the first term compares against torsion_deg_fail (strictly)
        # in the warn branch - possibly meant torsion_deg_warn.
        elif np.any((np.abs(torsions) > th["torsion_deg_fail"]) & (180 - np.abs(torsions) >= th["torsion_deg_warn"])):
            flags.append("torsion_local_warn")
        # a NaN mean is truthy but fails both comparisons, so it raises no flag
        if stats["ca_spacing_mean"] and (stats["ca_spacing_mean"] < th["ca_spacing_min"] or stats["ca_spacing_mean"] > th["ca_spacing_max"]):
            flags.append("ca_spacing_out_of_range")
        results["strands"].append({"idx": s_idx, "res": strand, "stats": stats, "flags": flags})

    return results


def get_biopython_model1(fp):
    """
    Parse a structure file and return its first model plus per-chain residues.

    fp: path to the structure; ".cif" / ".mmcif" files go through
      MMCIFParser, every other extension through PDBParser.
    Returns (model, chain_res_lists), where chain_res_lists is a list of
    (chain_id, residues) holding only standard amino acids; chains without
    any are left out.
    """
    extension = os.path.splitext(fp)[1].lower()
    is_mmcif = extension in ('.cif', '.mmcif')
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)
    structure = parser.get_structure(os.path.basename(fp), fp)

    # only the first model is processed
    model = next(structure.get_models())

    chain_res_lists = []
    for chain in model:
        amino_acids = [res for res in chain if is_aa(res, standard=True)]
        if not amino_acids:
            continue
        chain_res_lists.append((chain.get_id(), amino_acids))

    return model, chain_res_lists


def make_screenshot(pdb_file, chain='A', start=60, end=75, output='screenshot.png'):
    """
    Load a PDB file, color by B-factor, highlight a segment, and save a screenshot.

    pdb_file: structure file loaded into PyMOL as object "mol".
    chain, start, end: chain id and inclusive residue range shown as sticks
      and used to center/zoom the view.
    output: path of the ray-traced 1200x900 PNG.
    The PyMOL session is reinitialized first, so any previously loaded
    objects are discarded.
    """
    # Clear any previous session
    cmd.reinitialize()

    # Load the structure
    cmd.load(pdb_file, "mol")

    # Color the structure by B-factor
    cmd.set("depth_cue", 0)
    cmd.set("ray_trace_fog", 0)
    cmd.show("cartoon", "mol")
    cmd.spectrum("b", "red_white_blue", "mol")  

    # Highlight the residue range
    selection = f"chain {chain} and resi {start}-{end}"
    cmd.select("highlight", selection)
    cmd.set("stick_transparency", 0.4, "mol")
    cmd.show("sticks", "highlight")
    #cmd.color("grey", "highlight")

    # Make the rest semi-transparent for clarity
    #cmd.set("cartoon_transparency", 0.4, "mol")
    # NOTE: this value is overridden below, where ray_opaque_background is set to 1
    cmd.set("ray_opaque_background", 0)
    cmd.center("highlight")
    cmd.zoom("highlight", 10)

    # White background (not transparent)
    cmd.bg_color("white")
    cmd.set("ray_opaque_background", 1)

    # Ray trace and save the image
    cmd.set("antialias", 2)
    cmd.set("ray_trace_mode", 1)
    cmd.png(output, width=1200, height=900, ray=1)
    print(f"Saved screenshot as {output}")


def screenshot_segment(
    pdb_file,
    chain,
    start_resi,
    end_resi,
    screenshot_path="segment.png",
    cutoff=5.0,
    opaque_color="grey70",
    transparent_color="grey70",
    transparent_level=0.7,
    width=1600,
    height=1200,
):
    """
    Produce a PyMOL screenshot in which:
      - The selected segment is opaque, colored by B-factor, with backbone sticks
      - Everything else is shown as a transparent cartoon

    pdb_file: structure file loaded as object "prot".
    chain, start_resi, end_resi: chain id and inclusive residue range of the segment.
    screenshot_path: output PNG path (not ray-traced).
    cutoff: currently unused - the "nearby residues" selection it was meant
      for is commented out.
    opaque_color / transparent_color: base colors of segment and rest (the
      segment color is subsequently replaced by the B-factor spectrum).
    transparent_level: cartoon transparency applied to the rest.
    width, height: viewport and image size in pixels.

    Works inside PyMOL with 'run script.py' OR via pymol2 in Python.
    NOTE(review): unlike make_screenshot, the session is not reinitialized, so
    repeated calls load into the same "prot" object - confirm this is intended.
    """

    # Load structure
    cmd.load(pdb_file, "prot")

    # Define segment
    seg_sel = f"prot and chain {chain} and resi {start_resi}-{end_resi}"
    cmd.select("seg", seg_sel)

    # Nearby residues within cutoff Å (disabled)
    #cmd.select("nearby", f"byres (seg around {cutoff})")

    # Whole protein except the segment
    cmd.select("rest", "prot and not seg")

    # Representation
    cmd.hide("everything")
    cmd.show("cartoon", "prot")


    # Colors
    cmd.color(opaque_color, "seg")
    #cmd.color(opaque_color, "nearby")
    cmd.color(transparent_color, "rest")

    # Transparency settings
    #cmd.set("cartoon_transparency", 0.0, "nearby")
    cmd.set("cartoon_transparency", transparent_level, "rest")
    cmd.set("cartoon_transparency", 0.0, "seg")
    # B-factor spectrum overrides the flat segment color set above
    cmd.spectrum("b", "red_white_blue", "seg")
    cmd.show("sticks", "seg and backbone")

    # Orient camera nicely
    cmd.orient("seg")

    # White background (not transparent)
    cmd.bg_color("white")
    cmd.set("ray_opaque_background", 1)

    # Disable expensive stuff
    cmd.set("ray_trace_mode", 0)
    cmd.set("ray_shadows", 0)
    cmd.set("antialias", 0)

    # Set viewport BEFORE screenshot
    cmd.viewport(width, height)

    cmd.png(screenshot_path, width, height, ray=0)


def vmd_screenshot_segment(
    pdb_file,
    chain,
    start_resi,
    end_resi,
    screenshot="vmd_segment",
    cutoff=5.0,
    transparent_alpha=0.7,
    renderer="TachyonInternal",
    vmd_path="vmd"
):
    """
    Render a segment of a structure with VMD and save it as "<screenshot>.png".

    A Tcl script is written to a temporary file and run by VMD in text mode.
    The segment is drawn as an opaque cartoon colored by the Beta field and
    the rest of the protein as a transparent cartoon; the view is centered on
    the segment. VMD renders "<screenshot>.tga", which is converted to PNG
    and then removed.

    pdb_file: structure file, loaded by VMD as type pdb.
    chain, start_resi, end_resi: chain id and inclusive residue range.
    screenshot: output path without extension.
    cutoff: distance (Å) used by the "near"/"rest" atom selections of the
      script; those selections are created but not used for any representation.
    transparent_alpha: transparency of the rest (opacity = 1 - alpha).
    renderer: VMD render method.
    vmd_path: VMD executable.

    The temporary Tcl script is deleted once VMD has finished, whether or not
    the run succeeded.
    """

    segsel = f"protein and chain {chain} and resid {start_resi} to {end_resi}"

    tcl = f"""
mol new "{pdb_file}" type pdb waitfor all
set molID 0

# Define the segment selection string
set segsel "{segsel}"

# Segment atoms
set seg  [atomselect $molID $segsel]

# Atoms within cutoff to segment OR in the same residues
set near [atomselect $molID "protein and (same residue as ($segsel) or within {cutoff} of ($segsel))"]

# Everything else
set rest [atomselect $molID "protein and not (same residue as ($segsel) or within {cutoff} of ($segsel))"]

# Remove default rep
mol delrep 0 $molID

mol representation NewCartoon
mol color ColorID 7
mol selection "protein and not (same residue as ($segsel))"
mol material Transparent
mol addrep $molID

# Opaque target segment
mol representation NewCartoon
mol color Beta
mol selection $segsel
mol material Opaque
mol addrep $molID

# Set transparency
material change opacity Transparent {1.0 - transparent_alpha}

# Camera
# Update selection
$seg update

# 1) Compute geometric center of your segment
set center_xyz [measure center $seg]

# 2) Move the entire molecule so that the segment center goes to the origin
set all [atomselect top "all"]
set neg_center [vecscale -1 $center_xyz]
$all moveby $neg_center

# 3) Reset view + zoom factor
display resetview

# 4) set orthographic projection
display projection orthographic

# 5) zoom by scaling molecule
$all scaleby 10   ;# increase to zoom in, <1 to zoom out

color Display Background white

render {renderer} "{screenshot}.tga"

quit
"""

    # delete=False because VMD must be able to open the script after this
    # handle is closed; the file is removed explicitly below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tcl", mode="w") as f:
        tcl_path = f.name
        f.write(tcl)

    try:
        print(f"Running VMD script: {tcl_path}")
        subprocess.run([vmd_path, "-dispdev", "text", "-e", tcl_path])
    finally:
        # do not leave one stray script per screenshot in the temp directory
        os.remove(tcl_path)

    tga_path = f"{screenshot}.tga"
    png_path = f"{screenshot}.png"

    # Convert TGA to final PNG/JPEG
    print("Converting TGA → final format:", png_path)
    # context manager closes the image file before the TGA is deleted
    with Image.open(tga_path) as im:
        im.save(png_path)

    os.remove(tga_path)

    print("Saved:", png_path)


def print_results(results, entries, chain_id, fn, print_fig=False, fig_dir="."):
    """
    Print one tab-separated line per QC record that carries a "fail" flag.

    results: dict category -> list of records with "res" (positions into
      entries) and "flags" (list of strings).
    entries: per-residue dicts; the residue number is read from
      entries[pos]['res'].id[1].
    chain_id, fn: chain and file name echoed in the output line.
    print_fig: when true, a VMD screenshot of the segment is rendered into
      fig_dir and its base name is printed in the last column ("_" otherwise).
    Columns: file, chain, first residue, last residue, category, fail flags,
    figure code.
    """
    for category, records in results.items():
        for record in records:
            fail_flags = ",".join(flag for flag in record["flags"] if "fail" in flag)
            if not fail_flags:
                continue

            first_pos, last_pos = record['res'][0], record['res'][-1]
            s = entries[first_pos]['res'].id[1]
            e = entries[last_pos]['res'].id[1]

            fig_code = "_"
            if print_fig:
                fig_code = f'{category}_{fn.split("/")[-1].split(".")[0]}_{chain_id}_{s}_{e}'
                vmd_screenshot_segment(fn, chain_id, s, e, screenshot=f'{fig_dir}/{fig_code}')

            print(f"{fn}\t{chain_id}\t{s}\t{e}\t{category}\t{fail_flags}\t{fig_code}")


# ------------------------
# CLI & batch driver
# ------------------------
def scan_directory(dir_path):
    """
    Run helix and strand QC on every structure file of a directory.

    Files ending in .pdb, .cif, .mmcif or .ent are processed in sorted order.
    For each chain, helices come from DSSP and strands from find_beta_sheets;
    records with a "fail" flag are printed by print_results. Returns None.
    """
    structure_suffixes = (".pdb", ".cif", ".mmcif", ".ent")
    for fname in sorted(os.listdir(dir_path)):
        if not fname.endswith(structure_suffixes):
            continue
        fp = os.path.join(dir_path, fname)

        model, chain_res_lists = get_biopython_model1(fp)
        all_entries = backbone_model(chain_res_lists)
        all_helices, all_strands = detect_helices_strands_per_chain_with_DSSP(model, fp, chain_res_lists)

        for chain_id in all_entries:
            entries = all_entries[chain_id]
            helices = all_helices[chain_id]
            # Cα coordinates keyed by position within the chain
            ca_by_pos = {i : entry["CA"] for i, entry in enumerate(entries)}

            h_results = qc_helices(
                ca_by_pos,
                [h["range"] for h in helices],
                [h["turn"] for h in helices],
            )

            bridges, strands, orientations, beta_sheets, missed_dssp_strands = find_beta_sheets(entries, all_strands[chain_id])
            strand_ranges = [s["range"] for s in strands] + [s["range"] for s in missed_dssp_strands]
            s_results = qc_sheets(ca_by_pos, strand_ranges, beta_sheets, orientation_map=orientations)

            print_results(h_results | s_results, entries, chain_id, fname, print_fig=False, fig_dir="test_fig/")


# Script entry point: takes one positional argument (the directory of
# structure files) and runs the batch scan on it.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Detect beta-sheets/bridges/ladders in a directory of structures.")
    parser.add_argument("dir", help="Directory with PDB/mmCIF files")
    args = parser.parse_args()
    scan_directory(args.dir)
