#!/usr/bin/env python

from __future__ import with_statement
import os, signal, sys, time, thread, threading, tempfile, pickle, multiprocessing

####################################

def _env_int(name, default):
    """
    Return the environment variable ``name`` converted to an integer.

    ``default`` is returned when the variable is not set or when its
    value cannot be parsed as an integer.
    """
    try:
        return int(os.environ[name])
    except (KeyError, ValueError):
        return default

# Raw command line; the main loop below splits it into options and
# positional arguments.
argv = sys.argv

# How many times each file is run (see test_file) ...
numiteration = _env_int('SAGE_TEST_ITER', 1)

# ... and how many times the whole test run is repeated (main loop).
numglobaliteration = _env_int('SAGE_TEST_GLOBAL_ITER', 1)

# SAGE_ROOT and SAGE_LOCAL are required: a missing variable raises KeyError.
SAGE_ROOT = os.environ['SAGE_ROOT']
SAGE_SITE = os.path.realpath(os.path.join(os.environ['SAGE_LOCAL'],
                                          'lib', 'python', 'site-packages'))
BUILD_DIR = os.path.realpath(os.path.join(SAGE_ROOT, 'devel', 'sage', 'build'))

print('Global iterations: ' + str(numglobaliteration))
print('File iterations: ' + str(numiteration))

# NOTE(review): not referenced anywhere else in the visible part of this
# script; kept in case something outside this view reads it.
abort = False

def abspath(x):
    """
    Return the absolute form of the path ``x``, passed through
    strip_automount_prefix() so that an automount prefix is dropped
    (adjustment for NFS).
    """
    absolute = os.path.abspath(x)
    return strip_automount_prefix(absolute)

def strip_automount_prefix(filename):
    """
    Strip prefixes added on automounted filesystems in some cases,
    which make the absolute path appear hidden.

    The first component of the absolute path ``filename`` is dropped
    (``/prefix/a/b`` becomes ``/a/b``).  The shortened path is returned
    only if it exists and refers to the same file as ``filename``; in
    every other case (relative path, fewer than two components, missing
    file, different file) ``filename`` is returned unchanged.

    AUTHOR:
        -- Kate Minola
    """
    sep = os.path.sep
    parts = filename.split(sep, 2)
    # An absolute path with at least two components splits into
    # ['', first_component, rest]; anything else has no prefix to strip.
    if len(parts) < 3 or parts[0] != '':
        return filename
    candidate = sep + parts[2]
    try:
        # samefile compares device and inode of both paths.
        if os.path.samefile(filename, candidate):
            return candidate
    except OSError:
        # One of the two paths does not exist or cannot be stat'ed:
        # there is nothing to strip.
        pass
    return filename

def abs(f):
    """
    Build the ``sage -t`` command line for the file ``f``, using the
    current global options and the path of ``f`` relative to SAGE_ROOT.

    Note that inside this script the name shadows the builtin ``abs``.
    """
    relative = abs_sage_path(f)
    return "sage -t %s %s" % (opts, relative)

def sage_test_cmd(f):
    """
    Build the ``sage -t`` command line for ``f`` exactly as it was given,
    using the current global options.
    """
    template = "sage -t %s %s"
    return template % (opts, f)

def abs_sage_path(f):
    """
    Return the absolute path of ``f`` with the leading SAGE_ROOT directory
    (and the separator that follows it) cut off.
    """
    full = abspath(f)
    # +1 also drops the path separator after SAGE_ROOT.
    prefix_length = len(SAGE_ROOT) + 1
    return full[prefix_length:]

def skip(F):
    """
    Return True if the file ``F`` should not be tested.

    A file is skipped when any of the following holds:

    - it does not exist;
    - its directory contains a file ``nodoctest.py`` (a notice is printed);
    - its name starts with a dot, or its absolute path has a hidden
      (dot-prefixed) component;
    - its absolute path contains ``doc/output``;
    - its extension is not one of the tested source extensions;
    - the string ``nodoctest`` occurs within its first 50 characters.
    """
    if not os.path.exists(F):
        return True
    G = abspath(F)
    if os.path.exists(os.path.join(os.path.dirname(G), 'nodoctest.py')):
        # The lock is released even if printing fails.
        with printmutex:
            print("%s (skipping) -- nodoctest.py file in directory" % abs(F))
        return True
    filenm = os.path.split(F)[1]
    if filenm.startswith('.') or (os.path.sep + '.' in G.lstrip(os.path.sep + '.')):
        return True
    if os.path.join('doc', 'output') in G:
        return True
    if os.path.splitext(F)[1] not in ('.py', '.pyx', '.tex', '.pxi', '.sage', '.rst'):
        return True
    # Only the head of the file is needed; do the read last (after the
    # cheap path checks) and close the handle promptly.
    with open(G) as source:
        head = source.read(50)
    return 'nodoctest' in head

def test_file(F):
    """
    Run the doctests of the file ``F`` through ``sage-doctest`` and
    collect the outcome.

    The command is run up to ``numiteration`` times from the directory
    containing ``F``; the loop stops at the first run with a nonzero
    status.

    INPUT:
        F -- path of the file to test

    OUTPUT:
        a tuple ``(F, ret, finished_time, ol)`` where

        - ``ret`` is 0 if the command reported status 0; otherwise it is
          the negated status (divided by 256 first when the raw status
          returned by os.system is at least 256), or -5 if running the
          command raised an exception;
        - ``finished_time`` is the wall-clock duration in seconds of the
          last run (0 if an exception was raised);
        - ``ol`` is the standard output captured from the last run.
    """
    cmd = 'doctest ' + opts
    if SAGE_SITE in os.path.realpath(F) and '-force_lib' not in cmd:
        cmd += ' -force_lib'

    command = os.path.join(SAGE_ROOT, 'local', 'bin', 'sage-%s' % cmd)
    directory, filestr = os.path.split(F)

    # Defaults, so that a well-formed result is returned even when
    # numiteration is 0 and the loop body never runs.
    ret = 0
    finished_time = 0
    ol = ''

    outfile = tempfile.NamedTemporaryFile()
    try:
        # NOTE: the command line goes through the shell unquoted, so paths
        # containing whitespace or shell metacharacters are not supported.
        s = 'bash -c "%s %s > %s" ' % (command, filestr, outfile.name)
        for i in range(0, numiteration):
            os.chdir(directory)
            try:
                t = time.time()
                ret = os.system(s)
                finished_time = time.time() - t
            except (Exception, KeyboardInterrupt):
                # Deliberately broad: always hand back a complete result
                # tuple so that the caller can report the failure.
                outfile.seek(0)
                return (F, -5, 0, outfile.read())
            if ret >= 256:
                ret = ret // 256
            ret = -ret
            # The shell redirection rewrites the file on every run, so
            # rewind before reading instead of continuing at the old offset.
            outfile.seek(0)
            ol = outfile.read()
            if ret != 0:
                break
    finally:
        # Closing the temporary file also removes it.
        outfile.close()
    return (F, ret, finished_time, ol)

def process_result(result):
    """
    Display and record the outcome of one tested file.

    ``result`` is a tuple ``(F, ret, finished_time, ol)``.  When ``ret`` is
    nonzero, the test command for ``F`` is appended to the global list
    ``failed``, annotated according to ``ret``:

        -100 -- number of failed doctests, counted from the output ``ol``
        -4   -- exception from the doctest framework
        -3   -- killed/crashed
        -2   -- KeyboardInterrupt
        -1   -- file not found

    (any other nonzero value is recorded without annotation).  The test
    command, the captured output (if it is not blank) and the run time are
    printed, and the run time is stored in the global ``time_dict``.
    """
    global failed
    F = result[0]
    ret = result[1]
    finished_time = result[2]
    ol = result[3]

    # Annotations for the fixed status codes; -100 is handled separately
    # because its annotation depends on the captured output.
    annotations = {
        -4: " # Exception from doctest framework",
        -3: " # Killed/crashed",
        -2: " # KeyboardInterrupt",
        -1: " # File not found",
    }
    if ret != 0:
        if ret == -100:
            numfail = (ol.count('Expected:')
                       + ol.count('Expected nothing')
                       + ol.count('Exception raised:'))
            note = " # %s doctests failed" % numfail
        else:
            note = annotations.get(ret, "")
        failed.append(abs(F) + note)

    # Files below the starting directory are shown relative to it.
    if abspath(F).startswith(CUR):
        print(sage_test_cmd(F[len(CUR)+1:]))
    else:
        print(abs(F))

    if ol and not ol.isspace():
        # Drop a single trailing newline; print adds its own.
        if ol.endswith("\n"):
            ol = ol[:-1]
        print(ol)

    time_dict[abs_sage_path(F)] = finished_time
    print("\t [%.1f s]" % finished_time)

def infiles_cmp(a,b):
    """
    Compare function used to sort the list of filenames by the time they
    take to run, as recorded in the global ``time_dict``.

    A file without a recorded time sorts before every file that has one;
    two files without a recorded time compare as equal, so the comparison
    is consistent in both directions.

    OUTPUT:
        a negative, zero or positive integer, as expected from a
        ``cmp``-style function
    """
    key_a = abs_sage_path(a)
    key_b = abs_sage_path(b)
    a_known = key_a in time_dict
    b_known = key_b in time_dict
    if a_known and b_known:
        return cmp(time_dict[key_a], time_dict[key_b])
    if a_known:
        return 1
    if b_known:
        return -1
    return 0

def populatefilelist(filelist):
    """
    Populate the global list ``files`` from ``filelist``, expanding
    directories into the files they contain.

    Entries that are regular files are added (as absolute paths) unless
    skip() rejects them.  Every other entry is walked recursively,
    relative to the global ``CUR``; a file found this way is added when

    - its extension is one of the tested source extensions,
    - its directory does not contain a file named ``__nodoctest__``,
    - skip() does not reject it, and
    - it does not lie below BUILD_DIR.

    Subdirectories whose name contains ``#`` are not descended into.

    The work is done while holding the global ``filemutex``.

    OUTPUT:
        always 0
    """
    extensions = ('.sage', '.py', '.pyx', '.tex', '.pxi', '.rst')
    # "with" guarantees that the lock is released even on an exception.
    with filemutex:
        for FF in filelist:
            if os.path.isfile(FF):
                if not skip(FF):
                    # os.path.join returns FF itself when FF is absolute.
                    files.append(os.path.join(os.getcwd(), FF))
                continue

            walkdir = os.path.join(CUR, FF)
            for root, dirs, lfiles in os.walk(walkdir):
                # The marker file is looked up in this directory's own
                # listing, not in the global result list.
                if '__nodoctest__' not in lfiles:
                    for F in lfiles:
                        if os.path.splitext(F)[1] not in extensions:
                            continue
                        appendstr = os.path.join(root, F)
                        if skip(appendstr):
                            continue
                        if os.path.realpath(appendstr).startswith(BUILD_DIR):
                            continue
                        files.append(appendstr)
                # Prune in place so that os.walk does not descend into the
                # removed directories; slice assignment avoids removing
                # items from the list while iterating over it.
                # NOTE(review): D is a bare directory name, so the
                # os.path.sep + 'notes' test can never match; the intent
                # may have been to skip directories named 'notes' -- confirm.
                dirs[:] = [D for D in dirs
                           if not ('#' in D or (os.path.sep + 'notes' in D))]
    return 0


# Main driver: each global iteration parses the command line, collects the
# files to test, runs them in a pool of worker processes, prints a summary
# and updates the cached timings.
for gr in range(0,numglobaliteration):

    # Split the command line afresh from sys.argv on every pass, so that
    # the options are still present on the second and later global
    # iterations.  Arguments starting with '-' are options passed on to
    # the doctest command; the rest are positional.
    opts = ' '.join([X for X in sys.argv if X.startswith('-')])
    argv = [X for X in sys.argv if not X.startswith('-')]

    try:
        numthreads = int(argv[1])
        infiles = argv[2:]
    except (ValueError, IndexError):
        # Can't convert first arg to an integer (arg was probably
        # omitted), or there is no positional argument at all; the
        # latter is reported by the usage check below.
        numthreads = 1
        infiles = argv[1:]

    if '-sagenb' in opts:
        opts = opts.replace('--sagenb', '').replace('-sagenb', '')

        # Find SageNB's home.
        from pkg_resources import Requirement, working_set
        sagenb_loc = working_set.find(Requirement.parse('sagenb')).location

        # In case we're using setuptools' "develop" mode.
        if not SAGE_SITE in sagenb_loc:
            opts += ' -force_lib'

        infiles.append(os.path.join(sagenb_loc, 'sagenb'))

    if numthreads == 0:
        # Set numthreads to be the number of processors, with a default
        # maximum of 8.
        #
        # The detection of number of processors might not be reliable on some
        # platforms. On a Sun SPARC T5240 (t2.math), the reported number of
        # processors might not correspond to the actual number of processors.
        # See tickets #6283, #7011, and the file SAGE_ROOT/makefile.
        #
        # WARNING: If cpu_count() below reports <= 8 for your machine
        # and you *don't* want to use all the cores/processors on your
        # system for parallel doctesting, use a (sensible) positive
        # integer.
        #
        # If cpu_count() reports > 8 and you want to use that many
        # threads, you must manually specify the number of threads.
        try:
            numthreads = min(8, multiprocessing.cpu_count())
        except NotImplementedError:
            numthreads = 1

    if numthreads < 1 or len(infiles) == 0:
        if numthreads < 1:
            print("Usage: sage -tp <numthreads> <files or directories>: <numthreads> must be positive.")
        else:
            print("Usage: sage -tp <numthreads> <files or directories>: no files or directories specified.")
        print("For more information, type 'sage -advanced'.")
        sys.exit(1)

    infiles.sort()

    files = list()

    t0 = time.time()
    filemutex = threading.Lock()
    printmutex = threading.Lock()

    # Pick a filename for the timing files -- long vs normal
    long_timings = bool(opts.count("-long"))
    if long_timings:
        time_file_name = os.path.join(os.environ["SAGE_TESTDIR"],
                                      '.ptest_timing_long')
    else:
        time_file_name = os.path.join(os.environ["SAGE_TESTDIR"],
                                      '.ptest_timing')

    # Load the timings cached by an earlier run, if any.
    # NOTE: pickle must only be used on trusted data; the timing file is
    # expected to be one written by this script.
    try:
        with open(time_file_name) as time_file:
            time_dict = pickle.load(time_file)
        time_dict_old = dict(time_dict)
    except Exception:
        # Missing, unreadable or corrupt cache: start from scratch.
        time_dict = { }
        time_dict_old = None
        if long_timings:
            print("No long cached timings exist; will create for successful files.")
        else:
            print("No cached timings exist; will create for successful files.")
    else:
        if long_timings:
            print("Using long cached timings to run longest doctests first.")
        else:
            print("Using cached timings to run longest doctests first.")
    done = False

    CUR = abspath(os.getcwd())

    failed = []

    DOT_SAGE=os.environ['DOT_SAGE']
    TMP = os.path.join(DOT_SAGE, 'tmp', 'test')
    if not os.path.exists(TMP):
        os.makedirs(TMP)

    populatefilelist(infiles)
    # Sort the files by test time, longest first
    files.sort(infiles_cmp)
    files.reverse()
    interrupt = False

    numthreads = min(numthreads, len(files))  # don't use more threads than files

    if len(files) == 1:
        file_str = "1 file"
    else:
        file_str = "%i files" % len(files)
    if numthreads == 1:
        jobs_str = "using 1 thread"  # not in parallel if numthreads is 1
    else:
        jobs_str = "doing %s jobs in parallel" % numthreads

    print("Doctesting %s %s" % (file_str, jobs_str))

    # A pool cannot be created with zero workers, so it is only started
    # when there is something to test.
    if files:
        p = None
        try:
            p = multiprocessing.Pool(numthreads)
            for r in p.imap_unordered(test_file, files):
                # The format is  (F, ret, finished_time, ol)
                process_result(r)
        except KeyboardInterrupt:
            interrupt = True
        finally:
            # Shut the workers down so that they do not pile up across
            # global iterations or linger after an interrupt.
            if p is not None:
                p.terminate()
                p.join()
    print(" ")
    print("-" * 70)

    os.chdir(CUR)

    # Timings are only maintained for the standard option sets.
    standard_opts = (opts == "-long" or len(opts) == 0)

    if len(failed) == 0:
        if interrupt:
            print("Keyboard Interrupt: All tests that ran passed.")
        else:
            print("All tests passed!")
    else:
        if interrupt:
            print("Keyboard Interrupt, not all tests ran")
        elif standard_opts:
            # Keep the new timings of the files that passed; files that
            # failed fall back to their previously cached timing, if any.
            # Entries of "failed" look like "sage -t [-long] <path> # ...",
            # so the path is the 4th word with -long and the 3rd without.
            if opts == "-long":
                path_index = 3
            else:
                path_index = 2
            failed_files = set([entry.split('#')[0].split()[path_index]
                                for entry in failed])
            time_dict_ran = time_dict
            time_dict = { }
            for F in time_dict_ran:
                if F not in failed_files:
                    time_dict[F] = time_dict_ran[F]
            if time_dict_old is not None:
                for F in time_dict_old:
                    if F not in time_dict:
                        time_dict[F] = time_dict_old[F]
        print("\nThe following tests failed:\n")
        for entry in failed:
            print("\t " + entry)
        print("-" * 70)
    # Only update timings if we are doing something standard
    if standard_opts and not interrupt:
        with open(time_file_name,"w") as time_file:
            pickle.dump(time_dict, time_file)
        print("Timings have been updated.")

    print("Total time for all tests: %.1f seconds" % (time.time() - t0))
