#!/usr/bin/env python

"""
Given the output of doctest and a file, adjust the doctests so they won't fail.

Doctest failures due to exceptions are ignored.
"""

import os, sys

# Separator line found between the individual failure reports in the doctest
# output; the output is split on it.
sep = "**********************************************************************\n"


def _warn_not_replaced(expected, got):
    """Warn that one reported failure could not be applied to the source.

    ``expected`` and ``got`` are lists of lines taken from the failure report.
    """
    import warnings
    warnings.warn("Did not manage to replace %s with %s"%('\n'.join(expected), '\n'.join(got)))


def _indent_width(text):
    """Return the number of leading whitespace characters of ``text``."""
    return len(text) - len(text.lstrip())


def fix_doctests(src_in, doc_out):
    """Return ``src_in`` with failing expected outputs replaced by actual ones.

    ``src_in`` is the text of the source file and ``doc_out`` is the text of
    the doctest run on it.  ``doc_out`` is split on ``sep``; each piece that
    contains ``line <number>:``, ``Expected:`` and ``Got:`` sections is treated
    as one failure report.

    Pieces mentioning a traceback and pieces that do not look like a failure
    report are skipped silently.  A report that is recognised but cannot be
    applied (no ``Got:`` section, line number outside the file, expected text
    not found at that position, or source less indented than the report)
    triggers a warning and leaves the source untouched.

    Every line of the result is terminated by a newline.
    """
    v = src_in.splitlines()

    for report in doc_out.split(sep):
        if 'raceback' in report:
            continue

        # Locate the reported line number and the start of the expected text.
        i0 = report.find('line ')
        i = report.find(':\n')
        j = report.find('Expected:\n')
        if i0 == -1 or i == -1 or j == -1:
            continue
        try:
            # int() rather than eval(): the text comes from an external log
            # and must never be executed.
            k = int(report[i0 + 5:i])
        except ValueError:
            continue

        # partition() tolerates a missing marker (e.g. "Got nothing") as well
        # as a repeated one, where a two-way unpack of split() would raise.
        expected_text, marker, got_text = report[j + 10:].partition('\nGot:\n')
        expected = expected_text.splitlines()
        got = got_text.splitlines()

        # The reported number is used unchanged as a 0-based index, so v[k] is
        # the line *following* the reported one -- presumably the first line
        # of expected output of a single-line example.
        end = k + len(expected)
        if (not marker or not expected or not got
                or k < 0 or end > len(v)
                or any(text is None for text in v[k:end])):
            _warn_not_replaced(expected, got)
            continue

        # Guess by how much `got` must be further indented to match the
        # indentation of the source.
        n = _indent_width(v[k]) - _indent_width(got[0])

        # Double check that `expected` is indeed there in the file.
        if n < 0 or any(e.strip() != text.strip()
                        for e, text in zip(expected, v[k:end])):
            _warn_not_replaced(expected, got)
            continue

        # Stuff all of `got` into v[k] and mark the remaining `expected` lines
        # as deleted, so that the line numbering of v is preserved for the
        # reports still to come.
        v[k] = '\n'.join(' ' * n + text for text in got)
        for idx in range(k + 1, end):
            v[idx] = None

    return ''.join(text + '\n' for text in v if text is not None)


def main(argv):
    """Command line entry point; ``argv`` has the layout of ``sys.argv``.

    With one argument the doctests of that file are run through ``sage -t``
    and the output is captured in ``.tmp``; with two arguments the second one
    names a file holding doctest output.  The adjusted source is written to
    ``<source file>.out`` and a unified diff against the original is shown.
    """
    if len(argv) not in (2, 3):
        print("Usage: sage -fixdoctests <source file> [doctest output]")
        print("Creates file <source file>.out that would pass the doctests (modulo raised exception)")
        sys.exit(1)

    src_path = argv[1]
    # NOTE(review): the file name is interpolated unquoted into shell command
    # lines here and below; names containing spaces or shell metacharacters
    # will not work as intended.
    if len(argv) == 2:
        os.system('sage -t %s > .tmp'%src_path)
        doc_path = '.tmp'
    else:
        doc_path = argv[2]

    with open(src_path) as src_file:
        src_in = src_file.read()
    with open(doc_path) as doc_file:
        doc_out = doc_file.read()

    src_out = fix_doctests(src_in, doc_out)

    # Close (and thereby flush) the output before diff reads it.
    out_path = src_path + '.out'
    with open(out_path, 'w') as out_file:
        out_file.write(src_out)

    os.system('diff -u %s %s | more'%(src_path, out_path))


if __name__ == '__main__':
    main(sys.argv)


