import os, sys
from Bio.PDB import PDBParser, DSSP

# Report long runs of residues that DSSP assigns no secondary structure
# (anything other than H/I/G/E) but whose mean AIUPred score is below
# AIUPRED_DISORDER_THR.  Each residue of such a run is printed as
# "<structure path>\t<chain>\t<resnum><icode>\t0".
#
# Environment variables (read in main()):
#   MKDSSP           mkdssp executable passed to Bio.PDB.DSSP
#   AIUPRED_RESULTS  AIUPred output: ">name" header lines, each followed by
#                    lines of three whitespace-separated fields (the last is
#                    the score); a blank line ends a record's score lines
#   AFDB_DIR         prefix prepended verbatim to each record name to build
#                    the structure path (include the trailing "/" for a dir)

AIUPRED_DISORDER_THR = 0.3
# Minimum number of consecutive non-SS residues for a run to be evaluated.
MIN_SEGMENT_LEN = 10
# DSSP codes treated as secondary structure; every other code counts as coil.
STRUCTURED_SS = frozenset({"H", "I", "G", "E"})


def iter_aiupred_records(lines):
    """Yield ``(name, scores)`` for each ``>`` record in AIUPred output lines.

    Lines before the first header, and non-blank lines after a record's
    terminating blank line, are ignored.  Records with no score lines are
    skipped.  Each score line must have exactly three whitespace-separated
    fields; otherwise ``ValueError`` is raised.
    """
    name = None
    scores = []
    reading = False
    for line in lines:
        if line.startswith(">"):
            if name is not None and scores:
                yield name, scores
            name = line[1:].strip()
            scores = []
            reading = True
            continue
        if not reading:
            continue
        if not line.strip():
            reading = False
            continue
        _, _, score = line.split()
        scores.append(float(score))
    # Emit the final record too; the old loop only processed a record when
    # the *next* header arrived, so the last one in the file was lost.
    if name is not None and scores:
        yield name, scores


def low_disorder_segments(ss_seq, scores, threshold=AIUPRED_DISORDER_THR,
                          min_len=MIN_SEGMENT_LEN):
    """Return runs of "D" positions in *ss_seq* with a low mean score.

    A run is a maximal stretch of consecutive "D" characters of length
    >= *min_len* whose mean of the corresponding *scores* entries is
    < *threshold*.  Each run is returned as a list of 0-based indices, in
    sequence order.
    """
    segments = []
    run = []
    # The trailing "-" sentinel closes a run that reaches the end of the
    # chain so it is evaluated like any other (previously it was dropped).
    for i, code in enumerate(ss_seq + "-"):
        if code == "D":
            run.append(i)
            continue
        if len(run) >= min_len:
            mean = sum(scores[j] for j in run) / len(run)
            if mean < threshold:
                segments.append(run)
        run = []
    return segments


def process_record(name, scores, afdb_dir, dssp_path,
                   threshold=AIUPRED_DISORDER_THR):
    """Run DSSP on ``afdb_dir + name`` and print low-score coil residues.

    Prints a "does not exist" message if the structure file is missing, and
    an "ERROR LENGTHS" diagnostic if the residue count differs from
    ``len(scores)``; the record is skipped in both cases.
    """
    pdb_fn = afdb_dir + name
    if not os.path.exists(pdb_fn):
        print(f"{pdb_fn} does not exist")
        return

    structure = PDBParser(QUIET=True).get_structure('', pdb_fn)
    # DSSP only covers model 0, so iterate that same model's residues.
    model = structure[0]
    dssp_dict = dict(DSSP(model, pdb_fn, dssp=dssp_path))

    # (chain_id, residue_id) keys as used by the DSSP property map.
    residue_keys = [residue.get_full_id()[2:] for residue in model.get_residues()]
    ss_seq = "".join(
        "-" if dssp_dict[key][2] in STRUCTURED_SS else "D"
        for key in residue_keys
    )

    if len(ss_seq) != len(scores):
        print("ERROR LENGTHS", len(ss_seq), len(scores))
        print(ss_seq)
        print(scores)
        return

    for segment in low_disorder_segments(ss_seq, scores, threshold):
        for i in segment:
            chain_id, res_id = residue_keys[i]
            print(f"{pdb_fn}\t{chain_id}\t{res_id[1]}{res_id[2].strip()}\t0")


def main():
    """Process every AIUPred record named by the environment variables."""
    dssp_path = os.environ["MKDSSP"]
    aiupred_results_fn = os.environ["AIUPRED_RESULTS"]
    afdb_dir = os.environ["AFDB_DIR"]
    with open(aiupred_results_fn) as f:
        for name, scores in iter_aiupred_records(f):
            process_record(name, scores, afdb_dir, dssp_path)


if __name__ == "__main__":
    main()