import os, sys
from Bio.PDB import PDBParser, DSSP

# Base directory, taken from the DISPROT_DIR environment variable. The mapping
# file, the disorder FASTA and the <accession>.pdb structure files are all
# looked up under it.
# Example: "/workspaces/databases/AFDB/human/disprot/protein_structures/"
disprot_dir = os.environ["DISPROT_DIR"]
disprot_fn = disprot_dir + "/disprot_pdb-uniprot.txt"
disprot_seqs_fn = disprot_dir + "/disprot-2018-11-disorder.fasta"

# Mapping read from the two-column, whitespace-separated file `disprot_fn`:
# first column -> second column (used below to translate the identifier found
# in a FASTA header into the accession that names the PDB file).
du = {}
with open(disprot_fn) as f:
    for line in f:
        fields = line.split()
        if not fields:
            # Tolerate blank lines (e.g. an empty line at the end of the file)
            # instead of failing on the two-value unpacking below.
            continue
        # Still strict on purpose: a non-blank line must have exactly two fields.
        d, u = fields
        du[d] = u

# Read the FASTA file and keep, per mapped accession, the concatenated record
# body. Records whose header identifier is not a key of `du` are ignored.
# NOTE: if several records map to the same accession their bodies end up
# concatenated under that one key.
flag = False  # True while the record being read has a mapping in `du`
seq_parts = {}  # accession -> list of stripped body lines, joined at the end
with open(disprot_seqs_fn) as f:
    for line in f:
        if line.startswith(">"):
            # Header fields are "|"-separated; the second one is the identifier
            # looked up in `du`. A header without a second field is treated as
            # unmapped instead of raising IndexError, and the field is stripped
            # so it still matches when it is the last one on the line.
            fields = line.split("|")
            d = fields[1].strip() if len(fields) > 1 else None
            flag = d in du
            if flag:
                up = du[d]
        elif flag and line.strip():
            seq_parts.setdefault(up, []).append(line.strip())

disprot_seqs = {acc: "".join(parts) for acc, parts in seq_parts.items()}

# Path of the DSSP executable, taken from the MKDSSP environment variable.
dssp_path = os.environ["MKDSSP"]  # WARNING Check the path of mkdssp

# NOTE(review): never read or written in this script; kept in case something
# else expects the name.
af_errors = {}

# DSSP secondary-structure codes that count as ordered; every other code is
# flagged "D".
_ORDERED_SS_CODES = frozenset(("H", "I", "G", "E"))
# Minimum number of consecutive flagged residues for a run to be reported.
_MIN_RUN_LENGTH = 10


def _report_run(pdb_fn, run):
    """Print one row per residue of *run* if it has at least _MIN_RUN_LENGTH residues.

    *run* is a list of (chain_id, residue_id) pairs; a row is
    "<pdb_fn>\t<chain>\t<residue number><insertion code>\t1".
    """
    if len(run) < _MIN_RUN_LENGTH:
        return
    for chain_id, res_id in run:
        print(f"{pdb_fn}\t{chain_id}\t{res_id[1]}{res_id[2].strip()}\t1")


# One parser is enough for all files; it does not need to be rebuilt per protein.
pdb_parser = PDBParser(QUIET=True)

for up in disprot_seqs:
    pdb_fn = disprot_dir + "/" + up + ".pdb"
    if not os.path.exists(pdb_fn):
        # No structure file for this accession.
        continue
    structure = pdb_parser.get_structure('', pdb_fn)

    # DSSP is computed on the first model only, so only that model's residues
    # are walked below (the (chain, residue id) keys carry no model id).
    model = structure[0]
    dssp = DSSP(model, pdb_fn, dssp=dssp_path)
    dssp_dict = dict(dssp)

    # Per-residue key and secondary-structure flag, in residue order.
    # Index 2 of a DSSP entry is the secondary-structure code (index 3, the
    # relative solvent accessibility, is not used here).
    chrids = []
    ss_seq = []
    for residue in model.get_residues():
        _, _, ch, rid = residue.get_full_id()
        chrid = (ch, rid)
        chrids.append(chrid)
        ss = dssp_dict[chrid][2]
        ss_seq.append("-" if ss in _ORDERED_SS_CODES else "D")

    # The two annotations are compared position by position, so they must
    # describe the same number of residues; otherwise skip this protein.
    if len(disprot_seqs[up]) != len(ss_seq):
        continue

    # Report every stretch of consecutive residues flagged "D" by both DSSP
    # and the FASTA annotation.
    run = []
    for chrid, ss_flag, disprot_flag in zip(chrids, ss_seq, disprot_seqs[up]):
        if ss_flag == "D" and disprot_flag == "D":
            run.append(chrid)
        else:
            _report_run(pdb_fn, run)
            run = []
    # A stretch that reaches the last residue has no terminating residue after
    # it, so it has to be flushed explicitly.
    _report_run(pdb_fn, run)
