# Finds linear plasmids in unabridged Unicycler assembly graphs
# Alex Tokolyi, 2018

import argparse
import pathlib
import sys
import igraph

# Command-line interface: the graph file to scan is given with --gfa.
parser = argparse.ArgumentParser(
    description='LETS FIND SOME LINEAR PLASMIDS',
)
parser.add_argument(
    '--gfa',
    type=pathlib.Path,
    required=True,
)
args = parser.parse_args()
gfa = args.gfa  # pathlib.Path of the GFA file, used by the code below

graph = igraph.Graph(directed=False)
verts = set()
edges = []

# Every oriented segment (segment name + strand sign) seen on a GFA link
# ('L') line becomes a vertex name; every link becomes an undirected edge
# between the two oriented segments it joins.
with open(gfa, 'r') as raw:
    for line in raw:
        # Drop the newline first so it can never end up inside the last
        # field of a line (and from there inside a vertex name).
        fields = line.rstrip('\n').split('\t')
        if fields[0] == 'L':
            v1 = fields[1] + fields[2]
            v2 = fields[3] + fields[4]
            verts.add(v1)
            verts.add(v2)
            edges.append((v1, v2))

# Insert in bulk: one call for all vertices (the strings are stored in the
# "name" vertex attribute) and one call for all edges, instead of one igraph
# call per vertex and per edge.
if verts:
    graph.add_vertices(list(verts))
    graph.add_edges(edges)

def psg(sg):
    """Print the input GFA path followed by one vertex label in parentheses.

    The line written to stdout has the form ``<gfa> (<label>)``, where
    ``<gfa>`` is the module-level input path and ``<label>`` is the name of
    the last vertex in ``sg.vs`` with its final character removed (vertex
    names in this script end with the strand sign).

    Args:
        sg: Graph-like object whose ``vs`` iterates over vertices that expose
            ``attributes()`` returning a mapping with a ``"name"`` entry.
    """
    print(gfa, end=" (")
    # NOTE(review): each iteration overwrites the label, so only the last
    # vertex is reported. If all vertex names were meant to be listed, this
    # needs to accumulate instead -- confirm the intended output.
    label = ""
    for vertex in sg.vs:
        label = vertex.attributes()["name"]
    print(label[:-1], end=")\n")

# Count, once, how many oriented vertices each segment contributes to the
# whole graph, so every candidate below needs only a dictionary lookup.
strand_counts = {}
for name in verts:
    segment = name[0:-1]
    strand_counts[segment] = strand_counts.get(segment, 0) + 1

# Scan the connected components for the shape we are after: one main segment
# linked to both strands of a single end segment, and to nothing else.
for subgraph_nodes in graph.components():
    # Only 3-vertex components can match; skip the rest before paying for
    # the subgraph construction.
    if len(subgraph_nodes) != 3:
        continue
    subgraph = graph.induced_subgraph(subgraph_nodes, implementation='create_from_scratch')

    # 3 vertices and 2 edges (main > end +, main > end -)
    if subgraph.ecount() != 2:
        continue

    # Check the ends are reciprocal (same segment, + and - strand): with the
    # strand sign dropped, the 3 vertices must collapse to 2 segment names.
    sub_verts = {v["name"][0:-1] for v in subgraph.vs}
    if len(sub_verts) != 2:
        continue

    # In a connected 3-vertex, 2-edge graph exactly one vertex touches both
    # edges; that highest-degree vertex is the main segment.
    degrees = subgraph.degree()
    mode = degrees.index(max(degrees))
    big = subgraph.vs[mode]["name"][0:-1]

    # Check that the main segment has no vertex on the opposite strand
    # anywhere in the graph, i.e. no other connections.
    if strand_counts[big] == 1:
        psg(subgraph)
