from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw import MolsToGridImage
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions
# Do not label atoms with their indices in MolDrawing-based depictions.
DrawingOptions.includeAtomNumbers = False
# Ask the IPython console integration for non-SVG (raster) molecule images.
IPythonConsole.ipython_useSVG = False
def show_mol_grid(mols, sub_img_size=(250, 200)):
    """Render molecules as an RDKit grid image labelled by parent operator.

    Args:
        mols: iterable of morph objects; each must provide ``asRDMol()``
            and a ``parent_operator`` attribute (used as the legend).
        sub_img_size: ``(width, height)`` of each grid cell, passed to
            ``MolsToGridImage`` as ``subImgSize``.

    Returns:
        Whatever ``MolsToGridImage`` returns for the grid.
    """
    # The input is traversed twice (molecules + legends); materialise it so
    # one-shot iterables such as generators do not produce empty legends.
    mols = list(mols)
    return MolsToGridImage(
        [mol.asRDMol() for mol in mols],
        subImgSize=sub_img_size,
        legends=[mol.parent_operator for mol in mols],
    )
import molpher
from molpher.core.operations import *
from molpher.core import MolpherMol, ExplorationTree as ETree
class MyFilterMorphs(TreeOperation):
    """
    A custom tree operation that keeps only promising candidate morphs.

    It is meant to run after the list of candidates has been sorted (here
    by ``SortMorphsOper``). A candidate is accepted only if it is among the
    first ``max_morphs`` candidates AND its ``sascore`` is below
    ``max_sascore`` (presumably the synthetic accessibility score). Every
    other candidate is masked out.
    """

    # How many leading (best-ranked) candidates are eligible for acceptance.
    max_morphs = 10
    # Eligible candidates are rejected when their sascore is >= this value.
    max_sascore = 6

    def __call__(self):
        """
        Recompute the tree's candidate mask.

        This method is called automatically by the tree.
        The tree this operation is being run on is accessible
        from the 'tree' member of the class.
        """
        # Read the candidate sequence once rather than on every index lookup.
        # NOTE(review): the attribute may rebuild its sequence per access.
        candidates = self.tree.candidates
        self.tree.candidates_mask = [
            # Short-circuit: sascore is only read for the leading candidates.
            idx < self.max_morphs and candidates[idx].sascore < self.max_sascore
            for idx, _ in enumerate(self.tree.candidates_mask)
        ]
# Source and target molecules of the exploration.
cocaine = MolpherMol('CN1[CH]2CC[CH]1[CH](C(OC)=O)[CH](OC(C3=CC=CC=C3)=O)C2')
procaine = MolpherMol('O=C(OCCN(CC)CC)c1ccc(N)cc1')

# Create a tree that explores from cocaine towards procaine.
tree = ETree.create(source=cocaine, target=procaine)

# One exploration iteration is this ordered sequence of tree operations.
iteration = [
    GenerateMorphsOper(),
    SortMorphsOper(),
    MyFilterMorphs(),  # the custom filtering step defined above
    ExtendTreeOper(),
    PruneTreeOper(),
]

# Run a single iteration: apply every operation to the tree, in order.
for oper in iteration:
    tree.runOperation(oper)

# Inspect the outcome: generation counter, leaf count and the leaves.
print(tree.generation_count)
print(len(tree.leaves))
show_mol_grid(tree.leaves)
from molpher.core.operations.callbacks import TraverseCallback
class MyCallback(TraverseCallback):
    """
    Traversal callback that reports on every morph it is given.

    For each morph it prints a header ("# Root #" when the morph has no
    parent SMILES, "# Morph #" otherwise), the parent SMILES for non-root
    morphs, then the morph's own SMILES and its descendants.
    """

    def __call__(self, morph):
        """
        Print the report for a single morph; invoked for each morph in
        the tree, starting from the root and moving to the leaves.
        """
        parent_smiles = morph.getParentSMILES()
        if parent_smiles:
            print('# Morph #')
            print('Parent:', parent_smiles)
        else:
            print("# Root #")
        print('SMILES: ', morph.getSMILES())
        print('Descendents: ', morph.getDescendants())
# Instantiate the reporting callback, wrap it in a tree-traversal
# operation and run that operation on the current tree.
callback = MyCallback()
traverse = TraverseOper(callback=callback)
tree.runOperation(traverse)
def process(morph):
    """
    Print a short report about one morph of the exploration tree.

    Used as the plain-function callback for ``tree.traverse``. It prints
    "# Root #" when the morph has no parent SMILES (otherwise "# Morph #"
    followed by the parent SMILES), then the morph's SMILES and descendants.
    """
    parent_smiles = morph.getParentSMILES()
    is_root = not parent_smiles
    print("# Root #" if is_root else '# Morph #')
    if not is_root:
        print('Parent:', parent_smiles)
    print('SMILES: ', morph.getSMILES())
    print('Descendents: ', morph.getDescendants())
# Same per-morph report as MyCallback, driven by a plain function instead.
tree.traverse(process)

# Build a new tree from the XML template and show its parameters.
template_file = 'cocaine-procaine-template.xml'
tree = ETree.create(template_file)
print(tree.params)

# Run one iteration of the operations defined earlier on this tree.
for oper in iteration:
    tree.runOperation(oper)

# Report the new leaves, ordered by their distance from the target.
leaf_distances = [(x.getSMILES(), x.getDistToTarget()) for x in tree.leaves]
leaf_distances.sort(key=lambda pair: pair[1])
print(leaf_distances)
# Save the tree to a snapshot file, then create a new tree from it.
snapshot_file = 'snapshot.xml'
tree.save(snapshot_file)
new_tree = ETree.create(snapshot_file)
print(new_tree.params)
# The reloaded leaves should match those of the original tree. Print them:
# a bare expression statement is only displayed in a notebook and is
# silently discarded when this file is run as a script.
print(
    sorted(
        [(x.getSMILES(), x.getDistToTarget()) for x in new_tree.leaves],
        key=lambda pair: pair[1],
    )
)
# Fresh operation instances for the full search.
iteration = [
    GenerateMorphsOper(),
    SortMorphsOper(),
    MyFilterMorphs(),
    ExtendTreeOper(),
    PruneTreeOper(),
]
tree = ETree.create(source=cocaine, target=procaine)

# Upper bound on iterations so the script fails loudly instead of looping
# forever when the target is never reached.
max_iterations = 1000
counter = 0
while not tree.path_found:
    if counter >= max_iterations:
        raise RuntimeError(
            "No path to the target found after %d iterations" % max_iterations
        )
    for oper in iteration:
        tree.runOperation(oper)
    counter += 1
    print("Iteration", counter)
    # Leaf closest to the target. min() returns the same first-minimal
    # element as sorted(...)[0] without sorting the whole list, and it
    # yields None instead of raising when the tree has no leaves.
    print(
        min(
            ((x.getSMILES(), x.getDistToTarget()) for x in tree.leaves),
            key=lambda pair: pair[1],
            default=None,
        )
    )
show_mol_grid(tree.fetchPathTo(tree.params['target']))