#!/usr/bin/env python
import sys, re
"""
gapcode.py

Recode distinct gaps in an alignment as binary presence-absence
characters.  This is an implementation of the 'simple indel coding'
described in Simmons, MP and H Ochoterena. 2000. Syst. Biol.
49:369-381.

Written by Rick Ree <rree@fieldmuseum.org>
Version 2.1, 13 January 2008
License: Public Domain

Thanks to Pieter Pelser for helpful suggestions.

To use, prepare an input file with one or more non-interleaved
sequence blocks in this format:

[gene1]
seq1  CCCAAA----GGGGG------CCCCAAATTTTTTTTTTT
seq2  CCCAAA----GGGGG------CCCCAAATTTTTTTTTTT
seq3  CCCAAAAA--GGGGG------CCCCAAATTTT---TTTT
seq4  CCCAAAAA--GGGGGGG------CCAAATTTT---TTTT
seq5  --CAAAAAAAGGGGGGG------CCAAATTTT---TTTT
seq6  --CAAAAA??GGGGGGGAAAACCCCAAATTTTTTTTTTT

[gene2]
seq1  CC--AAAAAAGGGGGGGGGGGCCCCA---TTTTTTTTTT
seq2  CC--AAAAAAGGGGGGGG?????????????????????
seq3  CCCAA----AGGGGGGGGGGG------ATTTTTTTTTTT
seq4  CCCAA----AGGGGGGGGGGGCC--AAATTTTTTTTT--
seq5  CCCAAAAAAAGGGGGGGGGGGCC--AAATTTTTTTTT--
seq6  CCCAAAAAAAGGGGGGGAAAACC--AAATTTTTTTTT--

[gene3] etc...

The input file should not contain any other comments, data, etc.  Gene
blocks, if more than one, should not be broken up by blank lines, but
should be separated from each other by one or more blank lines.

Usage is then as follows:

  python gapcode.py < input.txt > output.txt

where output.txt contains the original data, appended by a set of gap
characters for each gene block:

[gene1: 39 chars, sites 1-39]
[5 gapchars: matrix chars 79-83]
[gapchar 1 (matrix char 79): gene1 pos 7-10, matrix pos 7-10]
[gapchar 2 (matrix char 80): gene1 pos 9-10, matrix pos 9-10]
[gapchar 3 (matrix char 81): gene1 pos 16-21, matrix pos 16-21]
[gapchar 4 (matrix char 82): gene1 pos 18-23, matrix pos 18-23]
[gapchar 5 (matrix char 83): gene1 pos 33-35, matrix pos 33-35]

[             ]
[            .]
seq1     1?100
seq2     1?100
seq3     01101
seq4     01011
seq5     00011
seq6     0?000

[gene2: 39 chars, sites 40-78]
[5 gapchars: matrix chars 84-88]
[gapchar 1 (matrix char 84): gene2 pos 3-4, matrix pos 42-43]
[gapchar 2 (matrix char 85): gene2 pos 6-9, matrix pos 45-48]
[gapchar 3 (matrix char 86): gene2 pos 22-27, matrix pos 61-66]
[gapchar 4 (matrix char 87): gene2 pos 24-25, matrix pos 63-64]
[gapchar 5 (matrix char 88): gene2 pos 27-29, matrix pos 66-68]

[             ]
[            .]
seq1     10001
seq2     10???
seq3     011?0
seq4     01010
seq5     00010
seq6     00010

[Total of 10 gaps coded as binary characters]
[CHARSET gene1_gapchars = 79-83]
[CHARSET gene2_gapchars = 84-88]
[CHARSET gapchars_all = 79-88]

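* NOTE that leading and trailing gaps in each sequence are converted to
  missing data (MISSING) before coding, and so are not coded as gap
  characters, unless CONVERT_ENDGAPS_TO_MISSING is set to False.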

"""

### editable options
GAPCHAR = "-"    # symbol that marks an alignment gap (a single character)
MISSING = "?"    # symbol that marks missing data
CONVERT_ENDGAPS_TO_MISSING = True    # recode leading/trailing gaps as MISSING
###

# P1 matches a run of gap symbols at the very start of a sequence, P2 a
# run at the very end.  GAPCHAR is escaped so that editing it to a regex
# metacharacter (e.g. "." or "*") still matches the literal symbol.
P1 = re.compile("^%s+" % re.escape(GAPCHAR))
P2 = re.compile("%s+$" % re.escape(GAPCHAR))

def findgaps(s):
    """
    Locate every run of gap symbols in a sequence.

    s: an aligned sequence (string)

    returns a list of (start, end) tuples, one per maximal run of
    GAPCHAR in s, in left-to-right order; both indices are 0-based and
    inclusive
    """
    # Escape GAPCHAR so that a symbol with a special meaning inside a
    # character class (e.g. "^", "]" or "\") is still matched literally.
    pat = re.compile("[%s]+" % re.escape(GAPCHAR))
    return [ (m.start(), m.end()-1) for m in pat.finditer(s) ]

def gapcode_matrix(taxa, matrix, position):
    """
    Code each distinct gap in an alignment block as a binary character.

    taxa: list of taxa
    matrix: list of sequences corresponding to taxa
    position: number assigned to first site in matrix

    returns a list of coded gap characters, ordered by gap start and
    then gap end.  Each item is a tuple (local_pos, global_pos, score):
      local_pos: 1-based (first, last) sites of the gap within matrix
      global_pos: (first, last) sites of the gap, numbered from position
      score: list with one state per taxon -- "1" if the taxon has
        exactly this gap, "?" if the gap region consists only of
        MISSING in that taxon or lies inside a longer gap of that
        taxon, and "0" otherwise
    """
    gapmat = [ findgaps(seq) for seq in matrix ]
    # Distinct (start, end) gaps over all sequences.  sorted() replaces
    # the dict.keys() + list.sort() idiom, which fails on Python 3 where
    # keys() returns a view object without a sort() method.
    gaps = sorted(set([ g for gaplist in gapmat for g in gaplist ]))
    coded = []
    for gap in gaps:
        first, last = gap
        local_pos = (first+1, last+1)
        global_pos = (first+position, last+position)
        all_missing = MISSING*(last - first + 1)
        score = ["0"]*len(taxa)
        for i, gaplist in enumerate(gapmat):
            # Nothing but missing data under the gap: state is unknown.
            if matrix[i][first:last+1] == all_missing:
                score[i] = "?"
                continue
            for g in gaplist:
                if gap == g:
                    score[i] = "1"
                    break
                elif (first >= g[0]) and (last <= g[1]):
                    # The gap is nested in a longer one, so its
                    # presence cannot be scored for this taxon.
                    score[i] = "?"
        coded.append((local_pos, global_pos, score))
    return coded

def scalebar(taxlen, seqlen):
    """
    Build a two-line bracketed ruler to print above an aligned matrix.

    taxlen: width of the text that precedes the first site on each
      matrix row (report() passes its padded taxon width plus one)
    seqlen: number of sites the ruler has to span

    returns a string holding two lines joined by a newline.  The first
    line carries the numbers 10, 20, ... each starting at its own site;
    the second carries a "." at every fifth site.  Both lines are cut to
    seqlen+taxlen-1 characters and wrapped in "[" and "]".
    """
    # Floor division: plain "/" yields a float on Python 3 (and under
    # "from __future__ import division"), which range() rejects.
    segments = seqlen//10
    width = seqlen + taxlen - 1
    numbers = "".join([ str((i+1)*10).ljust(10, " ")
                        for i in range(segments) ])
    numbers = (" "*(8+taxlen) + numbers)[:width]
    ticks = " "*(taxlen-1) + ".".join([ "    .    " ]*(segments+1))
    return "\n".join(("[%s]" % numbers, "[%s]" % ticks[:width]))
                    
def report(blocks):
    """
    Print the alignment blocks followed by their coded gap characters.

    blocks: list of (label, taxa, matrix, glopos, seqlen, coded) tuples
      as returned by parse_matrix()

    Everything is written to standard output: first the sequences of
    every block under a scale bar; then, for each block, a bracketed
    summary of its sites, one bracketed line per gap character and the
    matrix of gap characters; finally the total number of coded gaps
    and, if any gap was coded, the CHARSET lines.  Gap characters are
    numbered consecutively, starting right after the last site of the
    last block.

    Only single-argument print(...) calls are used, which produce the
    same output on Python 2 and Python 3.
    """
    if not blocks:
        # Nothing was parsed; indexing blocks[-1] would raise IndexError.
        print("[Total of 0 gaps coded as binary characters]")
        return

    last_block = blocks[-1]
    align_end = last_block[3] + last_block[4] - 1 # number of the last char
    gapchar_start = align_end + 1
    total_ncoded = 0
    gap_charsets = []

    # Part 1: the sequence data of every block.
    for label, taxa, matrix, glopos, seqlen, coded in blocks:
        maxtaxlen = max([ len(t) for t in taxa ]) + 4
        print(scalebar(maxtaxlen+1, seqlen))
        for t, s in zip(taxa, matrix):
            print("%s %s" % (t.ljust(maxtaxlen), s))
        print("")

    # Part 2: the gap characters of every block.
    for label, taxa, matrix, glopos, seqlen, coded in blocks:
        maxtaxlen = max([ len(t) for t in taxa ]) + 4
        ncoded = len(coded)
        print("[%s: %s chars, sites %s-%s]"
              % (label, seqlen, glopos, glopos+seqlen-1))
        if ncoded > 0:
            first_char = align_end + 1
            last_char = align_end + ncoded
            print("[%s gapchars: matrix chars %s-%s]"
                  % (ncoded, first_char, last_char))
            gap_charsets.append("[CHARSET %s_gapchars = %s-%s]"
                                % (label, first_char, last_char))
            for i, (loc, glo, score) in enumerate(coded):
                if loc[0] == loc[1]:
                    # A one-site gap is shown as a position, not a range.
                    print("[gapchar %s (matrix char %s): %s pos %s, "
                          "matrix pos %s]"
                          % (i+1, first_char+i, label, loc[0], glo[0]))
                else:
                    print("[gapchar %s (matrix char %s): "
                          "%s pos %s-%s, matrix pos %s-%s]"
                          % (i+1, first_char+i, label,
                             loc[0], loc[1], glo[0], glo[1]))
            print("")
            print(scalebar(maxtaxlen+1, ncoded))
            for i, tax in enumerate(taxa):
                row = "".join([ score[i] for loc, glo, score in coded ])
                print("%s %s" % (tax.ljust(maxtaxlen), row))
            align_end = last_char
            total_ncoded += ncoded
        else:
            print("[0 gapchars]")
        print("")

    print("[Total of %s gaps coded as binary characters]" % total_ncoded)
    if total_ncoded > 0:
        # Without any coded gap the overall range would run backwards
        # (start = end + 1), so the CHARSET lines are left out then.
        print("\n".join(gap_charsets))
        print("[CHARSET gapchars_all = %s-%s]" % (gapchar_start, align_end))

def parse_matrix(input):
    """
    Parse labeled, non-interleaved sequence blocks and gap-code each.

    input: the whole input text.  A block is a "[label]" line followed
      by one "name sequence" line per taxon; a blank line ends the block

    returns a list of (label, taxa, seqs, glopos, seqlen, coded)
    tuples, one per block in input order, where glopos is the number of
    the block's first site in the concatenated matrix (the first block
    starts at 1), seqlen is the length of the block's first sequence
    and coded is the result of gapcode_matrix() for the block.  When
    CONVERT_ENDGAPS_TO_MISSING is true, leading and trailing gap runs
    of each sequence are replaced by MISSING before coding.

    raises AssertionError if a sequence is found outside a labeled
    block or differs in length from the first sequence of its block;
    ValueError if a sequence line does not consist of exactly two
    whitespace-separated fields
    """
    taxa = []
    seqs = []
    blocks = []
    glopos = 1
    # Start without a label, so that a sequence on the very first line
    # is reported as unlabeled instead of raising UnboundLocalError.
    label = ""
    # The extra empty string acts as a closing blank line: the last
    # block is kept even if the input does not end with a newline.
    for line in input.split("\n") + [""]:
        s = line.strip()
        if not s:
            # A blank line closes the current block, if there is one.
            if taxa and seqs and label:
                seqlen = len(seqs[0])
                coded = gapcode_matrix(taxa, seqs, glopos)
                blocks.append((label, taxa, seqs, glopos, seqlen, coded))
                glopos += seqlen
            label = ""
            taxa = []
            seqs = []
            continue

        if s.startswith("[") and s.endswith("]"):
            label = s[1:-1]
            continue

        # Raised explicitly (rather than via "assert") so that the
        # checks are not stripped when Python runs with -O.
        if not label:
            raise AssertionError("found sequence in unlabeled block")
        t, v = line.split()
        if CONVERT_ENDGAPS_TO_MISSING:
            for p in (P1, P2):
                m = p.search(v)
                if m:
                    start, end = m.span()
                    v = v[:start] + MISSING*(end - start) + v[end:]
        taxa.append(t); seqs.append(v)
        if len(v) != len(seqs[0]):
            raise AssertionError(
                "sequence %s in block %s has length %s, expected %s"
                % (t, label, len(v), len(seqs[0])))

    return blocks

if __name__ == "__main__":
    # Sample input for manual testing (feed it on standard input):
    #
    # [gene1]
    # seq1  CCCAAA----GGGGG------CCCCAAATTTTTTTTTTT
    # seq2  CCCAAA----GGGGG------CCCCAAATTTTTTTTTTT
    # seq3  --CAAAAA--GGGGG------CCCCAAATTTT---TTTT
    # seq4  --CAAAAA--GGGGGGG------CCAAATTTT---TTTT
    # seq5  CCCAAAAAAAGGGGGGG------CCAAATTTT---TTT-
    # seq6  CCCAAAAA??GGGGGGGAAAACCCCAAATTTTTTTTTT-
    #
    # [gene2]
    # seq1  CC--AAAAAAGGGGGGGGGGGCCCCA---TTTTTTTTTT
    # seq2  CC--AAAAAAGGGGGGGG?????????????????????
    # seq3  CCCAA----AGGGGGGGGGGG------ATTTTTTTTTTT
    # seq4  CCCAA----AGGGGGGGGGGGCC--AAATTTTTTTTTTT
    # seq5  CCCAAAAAAAGGGGGGGGGGGCC--AAATTTTTTTTTTT
    # seq6  CCCAAAAAAAGGGGGGGAAAACC--AAATTTTTTTTTTT
    #
    # Read the alignment from standard input, code its gaps and print
    # the report to standard output.
    data = sys.stdin.read()
    blocks = parse_matrix(data)
    report(blocks)
