#!/usr/bin/env python3

import os
import re
import sys
import subprocess
import argparse
from collections import defaultdict
from shutil import which

# default name of program to execute, can be overridden by the same-named environment variable or via
# commandline option
DARKTABLE_CLI = 'darktable-cli'

# default name of directory in which to create a scratch folder for darktable to use as config directory,
# can be overridden by the same-named environment variable, environment variable TMPDIR, or via commandline option
# (when both environment variables are set, DARKTABLE_TMP wins over TMPDIR)
DARKTABLE_TMP = '/tmp'

# when True, print extra diagnostics such as every location the locate_* helpers tried without success;
# set from the --verbose commandline option by parse_commandline()
VERBOSE = False

def whereami():
   '''whereami: retrieve the directory containing this script

   The directory is derived from sys.argv[0].  The result always ends in '/', because callers
   append a file name to it directly.

   args: none
   returns: str(path) = directory of the script, with trailing '/'
   '''
   path, sep, _ = sys.argv[0].rpartition('/')
   if not sep:
      ## invoked without a pathname, so assume running from current directory
      ## (rpartition leaves the separator empty when argv[0] contains no '/')
      return os.getcwd() + '/'
   path = path + '/'
   if path[0] == '/' or path[0] == '\\':
      # already absolute
      return path
   # relative path, so prepend current working directory
   return os.getcwd() + '/' + path

def locate_program(program):
   '''locate executable in standard locations if specified without a path

   A name containing a path separator is taken as a path and returned without checking that it
   exists (relative paths are made absolute against the current directory).  A bare name is searched
   for, in order, in the build/bin directory of the source tree containing this script, in the
   build/bin directory of a sibling 'darktable' source tree, and on the search path.

   args: program = the name of the program to locate
   returns: full pathname of program; the script exits with status 1 if it cannot be found
   '''
   import sys  # local import so this block does not depend on the file-level import list

   if '/' in program or '\\' in program:
      # is this a relative or absolute path?
      if program[0] == '/' or program[0] == '\\':
         return program
      # relative path, so prepend current working directory
      return os.getcwd() + '/' + program
   # no path given, so check the standard locations
   loc = whereami()
   # are we in the source tree?
   if 'src/tests/benchmark' in loc:
      # check if available in the build directory
      build, _, __ = loc.partition('src/tests/benchmark')
      build += 'build/bin/'
      if os.path.exists(build+program):
         return build+program
      if VERBOSE:
         print(f'  did not find {program} in build/bin')
   # are we in a sibling of the source tree?
   if os.path.exists('../darktable/src/tests/benchmark'):
      # anchor at the current directory so that a full pathname is returned, as documented
      build = os.getcwd() + '/../darktable/build/bin/'
      if os.path.exists(build+program):
         return build+program
      if VERBOSE:
         print(f'  did not find {program} in {build}')
   # last location: check if on search path
   onpath = which(program)
   if onpath:
      return onpath
   if VERBOSE:
      print(f'  did not find {program} on search path')
   # finally: give up
   print(f'Unable to locate {program}')
   sys.exit(1)

def locate_image(image):
   '''locate benchmark image in standard locations if specified without a path

   A name containing a path separator is taken as a path and returned without checking that it
   exists (relative paths are made absolute against the current directory).  A bare name is searched
   for in the current directory, next to this script, and in the integration-test image directory
   of the darktable source tree as reached from the script location, from the current directory, or
   from a sibling 'darktable' directory.

   args: image = the name of the image file to locate
   returns: full pathname of image; the script exits with status 1 if it cannot be found
   '''
   import sys  # local import so this block does not depend on the file-level import list

   if '/' in image or '\\' in image:
      # is this a relative or absolute path?
      if image[0] == '/' or image[0] == '\\':
         return image
      # relative path, so prepend current working directory
      return os.getcwd() + '/' + image
   # check standard locations if no path given
   # first: current directory
   if os.path.exists(image):
      return os.getcwd() + '/' + image
   if VERBOSE:
      print(f'  did not find {image} in current directory')
   # second: same directory as this script
   loc = whereami()
   if os.path.exists(loc+image):
      return loc+image
   if VERBOSE:
      print(f'  did not find {image} in {loc}')
   # third: script is in source dir, so look in integration tests
   if 'src/tests/benchmark' in loc:
      integ, _, __ = loc.partition('benchmark')
      integ += 'integration/images/'
      if os.path.exists(integ+image):
         return integ+image
      if VERBOSE:
         print(f'  did not find {image} in {integ}')
   # fourth: we are in integration-test dir and script is in source dir, so look in integration tests
   if 'integration/../benchmark' in loc:
      integ, _, __ = loc.partition('/../benchmark')
      integ += '/images/'
      if os.path.exists(integ+image):
         return integ+image
      if VERBOSE:
         print(f'  did not find {image} in {integ}')
   # fifth: script is not in source dir, but current dir is top of source tree
   # (look for the requested image, not a fixed file name, so a different image is never
   #  silently substituted for the one that was asked for)
   integ = os.getcwd() + '/src/tests/integration/images/'
   if os.path.exists(integ+image):
      return integ+image
   if VERBOSE:
      print(f'  did not find {image} in {integ}')
   # sixth: we are in a sibling of the darktable source tree
   integ = os.getcwd() + '/../darktable/src/tests/integration/images/'
   if os.path.exists(integ+image):
      return integ+image
   if VERBOSE:
      print(f'  did not find {image} in {integ}')
   # finally: give up
   print(f'Unable to locate {image}')
   sys.exit(1)

def locate_xmp(xmp,version):
   '''locate sidecar file in standard locations if specified without a path

   The sidecar name is normalized first: a trailing '.xmp' is dropped, '-<version>' is appended when
   a version is given, and the '.xmp' extension is put back.  A name containing a path separator is
   then taken as a path and returned without checking that it exists (relative paths are made
   absolute against the current directory).  A bare name is searched for in the current directory,
   next to this script, and in src/tests/benchmark below the current directory.

   args: xmp = the name of the sidecar file to locate, with or without the '.xmp' extension
         version = version tag to insert before the extension; may be empty or None for no tag
   returns: full pathname of sidecar; the script exits with status 1 if it cannot be found
   '''
   import sys  # local import so this block does not depend on the file-level import list

   if xmp.endswith('.xmp'):
      xmp = xmp[:-4]
   if version:
      xmp = xmp + '-' + version
   # always restore the extension, so that an empty version does not leave an extensionless name
   xmp = xmp + '.xmp'
   if '/' in xmp or '\\' in xmp:
      # is this a relative or absolute path?
      if xmp[0] == '/' or xmp[0] == '\\':
         return xmp
      # relative path, so prepend current working directory
      return os.getcwd() + '/' + xmp
   # check standard locations if no path given
   # first: current directory
   if os.path.exists(xmp):
      return os.getcwd() + '/' + xmp
   if VERBOSE:
      print(f'  did not find {xmp} in current directory')
   # second: same directory as this script
   loc = whereami()
   if os.path.exists(loc+xmp):
      return loc+xmp
   if VERBOSE:
      print(f'  did not find {xmp} in {loc}')
   # third: script is not in source dir, but current dir is top of source tree
   xmp_path = os.getcwd() + '/src/tests/benchmark/' + xmp
   if os.path.exists(xmp_path):
      return xmp_path
   if VERBOSE:
      print(f'  did not find sidecar at {xmp_path}')
   # fourth: TODO?

   # finally: give up
   print(f'Unable to locate sidecar {xmp}')
   sys.exit(1)


def parse_commandline():
   '''assemble the benchmark settings from the environment and the commandline

   Environment variables DARKTABLE_CLI, TMPDIR and DARKTABLE_TMP (the last one winning over TMPDIR)
   replace the built-in defaults for the executable name and the scratch directory; commandline
   options in turn override those.  The program, image and sidecar names are resolved through the
   locate_* helpers, and the scratch directory is created.

   args: none (reads os.environ and sys.argv)
   returns: (args, remargs) = the parsed option namespace, extended with image_base, xmp0 and the
            resolved paths, and the list of unrecognized arguments (empty on return, because any
            unrecognized argument makes the script print its usage and exit)
   '''
   global DARKTABLE_CLI, DARKTABLE_TMP, VERBOSE
   env = os.environ
   DARKTABLE_CLI = env.get('DARKTABLE_CLI', DARKTABLE_CLI)
   # DARKTABLE_TMP beats TMPDIR, which beats the built-in default
   DARKTABLE_TMP = env.get('DARKTABLE_TMP', env.get('TMPDIR', DARKTABLE_TMP))

   parser = argparse.ArgumentParser(description="Darktable performance benchmarking")
   add = parser.add_argument
   add("-i", "--image", metavar="FILE", default="mire1.cr2",
       help="the name of the image to use")
   add("-v", "--version", metavar="V", default="3.6",
       help="look for darktable version V sidecar")
   add("-x", "--xmp", metavar="FILE", default="darktable-bench",
       help="the root name of the .xmp sidecar file to use")
   add("-p", "--program", metavar="EXE", default=DARKTABLE_CLI,
       help="full path to darktable-cli executable")
   add("-r", "--reps", metavar="N", type=int, choices=range(1,10), default=3,
       help="run N times and report average time")
   add("-t", "--threads", metavar="N", default=None,
       help="tell darktable-cli to use N threads")
   add("-C", "--cpuonly", action="store_true", default=False,
       help="disable OpenCL GPU acceleration")
   add("-T", "--tempdir", metavar="DIR", default=DARKTABLE_TMP,
       help="directory in which to create test data")
   add("-I", "--iopstats", metavar="FILE", default=None,
       help="file where per-iop times should be written (as CSV)")
   add("--verbose", action="store_true")
   if not sys.argv:
      parser.print_usage()
      parser.exit()

   args, remargs = parser.parse_known_args()
   VERBOSE = bool(args.verbose)

   # resolve the three inputs; each helper terminates the script if its file cannot be found
   args.program = locate_program(args.program)
   args.image = locate_image(args.image)
   # the bare file name of the image is what gets shown in the report
   args.image_base = args.image.rpartition('/')[2] or args.image
   # keep the unresolved sidecar root name; it is needed again to find the warm-up sidecar
   args.xmp0 = args.xmp
   args.xmp = locate_xmp(args.xmp, args.version)

   if remargs:
      parser.print_usage()
      parser.exit()

   # inside an existing directory, work in a private per-process subdirectory;
   # a directory that does not exist yet is created and used as is
   if os.path.exists(args.tempdir):
      args.tempdir = args.tempdir + '/dtbench' + str(os.getpid())
   os.mkdir(args.tempdir)

   if VERBOSE:
      print('  found:')
      for located in (args.program, args.image, args.xmp):
         print(f'     {located}')
   return args, remargs

def extract_seconds(line):
   '''extract the number of seconds from a "... took N.NNN secs ..." trace line

   args: line = one line of darktable's performance trace
   returns: float(seconds) = the number between the first 'took' and the following 'sec',
            or 0.0 if the line has no such pair or the text between them is not a number
   '''
   _, took, rest = line.partition('took')
   if not took:
      return 0.0
   value, sec, _ = rest.partition('sec')
   if not sec:
      return 0.0
   try:
      # float() ignores the surrounding whitespace
      return float(value)
   except ValueError:
      # not a timing line after all (e.g. 'took' and 'sec' appear in ordinary text)
      return 0.0
   
def run_benchmark(program,image,xmp,args):
   '''run the program once on the image and collect the timings from its performance trace

   The program is invoked with an in-memory library, args.tempdir as its configuration directory
   and '-d perf', with LANG and LC_ALL forced to 'C' in this process' environment.  The output image
   path is stored in args.outimage.

   args: program = pathname of the executable to run
         image = pathname of the input image
         xmp = pathname of the sidecar file
         args = option namespace; tempdir, threads and cpuonly are read, outimage is set
   returns: (pixpipe, total, gpu, iop_times) = pixelpipe seconds (-1.0 if the run was interrupted
            from the keyboard), load + pixelpipe + save seconds, whether any trace line mentioned
            'GPU', and a dict mapping iop name to its reported seconds
   '''
   workdir = args.tempdir
   target = workdir + '/darktable-bench.png'
   args.outimage = target
   # start from a clean slate so a stale output file is never mistaken for a fresh one
   if os.path.exists(target):
      os.remove(target)

   command = [program, "--hq", "1", image, xmp, target,
              "--core", "--library", ":memory:", "--configdir", workdir, "-d", "perf"]
   if args.threads:
      command += ["-t", args.threads]
   if args.cpuonly:
      command.append("--disable-opencl")

   # force the C locale for the child process
   for variable in ('LANG', 'LC_ALL'):
      os.environ[variable] = 'C'

   pixpipe = 0.0
   try:
      raw = subprocess.check_output(command, stdin=None, stderr=subprocess.PIPE, env=os.environ)
   except KeyboardInterrupt:
      print('KeyboardInterrupt')
      raw = ''
      pixpipe = -1.0   # tells the caller that this run does not count
   lines = raw.decode('utf-8').splitlines() if raw else []

   iop_pattern = re.compile(r"took (\d+\.\d+) .+ processed `(.+)'")
   loadtime = 0.0
   savetime = -1
   gpu = False
   iop_times = {}
   for entry in lines:
      gpu = gpu or 'GPU' in entry
      if 'to load the image' in entry:
         loadtime = extract_seconds(entry)
         continue
      if 'to save the image' in entry:
         ## dt doesn't yet report file write times, so this does nothing right now
         savetime = extract_seconds(entry)
         continue
      if 'pipeline processing took' in entry:
         pixpipe = extract_seconds(entry)
         continue
      hit = iop_pattern.search(entry)
      if hit:
         iop_times[hit.group(2)] = float(hit.group(1))

   if savetime < 0:
      # if no reported save time, assume it's the same as the time to load the image
      savetime = loadtime
   return pixpipe, loadtime+pixpipe+savetime, gpu, iop_times

def warm_up_caches(program,image,xmp,args):
   '''run the program once, untimed, on the '-null' variant of the sidecar before measuring

   The results of this run are discarded.

   args: program = pathname of the executable to run
         image = pathname of the input image
         xmp = root name of the sidecar; the '<xmp>-null.xmp' file is the one that gets used
         args = option namespace handed through to run_benchmark
   returns: nothing
   '''
   null_xmp = locate_xmp(xmp,'null')
   if not null_xmp:
      return
   if VERBOSE:
      print(f'     {null_xmp}')
   # no newline yet: 'done' is appended to the same line once the run has finished
   print('Preparing...',end='',flush=True)
   run_benchmark(program,image,null_xmp,args)
   print('done')

def get_version(program):
   '''ask the program for its version

   Runs 'program --version' and looks for 'this is ' in the first line of its output; whatever
   follows is returned with '-cli' removed.

   args: program = pathname of the executable to query
   returns: str(version), or "(Undetermined darktable version)" if the program could not be run,
            exited with an error, printed nothing, or printed an unexpected first line
   '''
   unknown = "(Undetermined darktable version)"
   try:
      output = subprocess.check_output([program,"--version"],stdin=None,stderr=subprocess.PIPE)
   except (OSError, subprocess.CalledProcessError):
      # the version is only used to label the report, so do not abort the script over it
      return unknown
   lines = output.decode('utf-8',errors='replace').splitlines()
   if not lines:
      return unknown
   _, marker, version = lines[0].partition('this is ')
   if not marker:
      return unknown
   return version.replace('-cli','')

def print_performance(pixpipe,total,dtversion,xmpversion,imagename,threads,used_gpu):
   '''print the benchmark report

   args: pixpipe = average pixelpipe processing time in seconds
         total = average overall processing time in seconds
         dtversion = version string of the program that was benchmarked
         xmpversion = version of the benchmark sidecar
         imagename = name of the image that was processed
         threads = number of threads requested, or None/empty to omit that line
         used_gpu = whether the GPU was used
   returns: nothing
   '''
   gpu = "using GPU" if used_gpu else "CPU only"
   print('')
   print(f'{dtversion} ::: benchmark v{xmpversion} ::: image {imagename}')
   if threads:
      print(f'Number of threads used:               {threads:>7}')
   print(f'Average pixelpipe processing time:    {pixpipe:7.3f} seconds')
   print(f'Average overall processing time:      {total:7.3f} seconds')
   # the rating is 3600 divided by the average time; when no time was measured at all, report
   #  a rating of zero instead of dividing by zero
   thruput = 3600 / total if total > 0 else 0.0
   print(f'Throughput rating (higher is better): {thruput:7.1f} ({gpu})')

def cleanup(args):
   '''remove the output image and the scratch directory, on a best-effort basis

   Files that cannot be removed are skipped, and the scratch directory is left in place if anything
   remains in it.

   args: args = option namespace; outimage and tempdir are read, and either may be missing or empty
   returns: nothing
   '''
   outimage = getattr(args,'outimage',None)
   if outimage:
      try:
         os.remove(outimage)
      except OSError:
         pass  # already gone or not removable; nothing useful to do about it here
   tempdir = getattr(args,'tempdir',None)
   if not tempdir:
      return
   try:
      leftovers = os.listdir(tempdir)
   except OSError:
      return
   # delete files in temp dir; since the ones darktable-cli creates all start with 'dar' or 'dat', limit the
   #  deletion to such files just in case
   for f in leftovers:
      if f.startswith(('dar','dat')):
         try:
            os.remove(tempdir+'/'+f)
         except OSError:
            pass  # keep going with the remaining files
   try:
      os.rmdir(tempdir)
   except OSError:
      pass  # directory still holds files we deliberately did not touch

def write_iop_stats(iop_times, filename):
   '''write the per-iop times to a text file, one "name; seconds" line per iop

   The file starts with a header line; the iop lines follow in order of iop name, with the
   time given to three decimal places.

   args: iop_times = dict mapping iop name to time in seconds
         filename = pathname of the file to (over)write
   returns: nothing
   '''
   names_in_order = sorted(iop_times)
   with open(filename, "w") as fout:
      fout.write("iop name; iop execution time (s)\n")
      for name in names_in_order:
         fout.write(f"{name}; {iop_times[name]:.3f}\n")

def main():
   '''run the benchmark: one warm-up pass, args.reps timed passes, then report the averages

   Runs that were interrupted from the keyboard are left out of the averages; if no run completed,
   999.9 is reported for every time.  The scratch files are removed on the way out, whether the
   benchmark finished, failed or exited early.

   args: none
   returns: nothing
   '''
   args, _ = parse_commandline()
   # run_benchmark() is what normally sets this; give it a value up front so that cleanup() can
   #  be called safely even if we never get as far as the first run
   if not hasattr(args,'outimage'):
      args.outimage = None
   try:
      warm_up_caches(args.program,args.image,args.xmp0,args)
      total = 0.0
      pixpipe = 0.0
      used_gpu = False
      reps_run = 0
      iop_times = defaultdict(float)
      for rep in range(args.reps):
         if args.reps > 1:
            print('     run #',rep+1,end='')
         p, t, g, iops = run_benchmark(args.program,args.image,args.xmp,args)
         for iop_name, iop_time in iops.items():
            iop_times[iop_name] += iop_time
         if p < 0.0:
            # interrupted run: leave it out of the averages
            continue
         pixpipe += p
         total += t
         reps_run += 1
         if g:
            used_gpu = True
         if args.reps > 1:
            print(f': {p:7.3f} pixpipe,  {t:7.3f} total')
      total = total / reps_run if reps_run > 0 else 999.9
      pixpipe = pixpipe / reps_run if reps_run > 0 else 999.9
      print_performance(pixpipe,total,get_version(args.program),args.version,args.image_base,args.threads,used_gpu)
      if args.iopstats:
         # turn the per-iop sums into per-run averages (only values of existing keys are replaced)
         for iop_name, iop_time in iop_times.items():
            iop_times[iop_name] = iop_time / reps_run if reps_run > 0 else 999.9
         write_iop_stats(iop_times, args.iopstats)
   finally:
      # do not leave the scratch directory behind when something above raises or exits
      cleanup(args)

# run the benchmark only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
   main()
