#!/usr/bin/python

import os
import sys
import argparse

import requests

def _default_token_path():
    """Return the first existing default token file path, or None.

    Locations are checked in order: the path named by the ``SCITOKEN``
    environment variable, ``~/.scitokens/token``, then
    ``/tmp/scitoken_u<effective uid>``.
    """
    env_path = os.environ.get('SCITOKEN')
    if env_path and os.path.exists(env_path):
        return env_path

    home_path = os.path.expanduser("~/.scitokens/token")
    if os.path.exists(home_path):
        return home_path

    tmp_path = '/tmp/scitoken_u%d' % os.geteuid()
    if os.path.exists(tmp_path):
        return tmp_path

    return None


def parse_args():
    """Parse the command line and resolve which token file each side uses.

    Returns:
        The ``argparse`` namespace with ``src``, ``dest``, ``mode``
        (``"push"``, ``"pull"`` or ``"auto"``), ``overwrite``, ``streams``,
        ``token``, ``src_token`` and ``dest_token`` populated.  When a
        per-side token is not given it falls back to ``token``; when that is
        missing too, the default token file locations are searched.

    Exits the process with status 1 (message on stderr) when ``--streams``
    is less than 1 or when a token is needed and no default token file
    exists.
    """
    parser = argparse.ArgumentParser(description='Drive test TPC transfers')
    parser.add_argument("-t", "--token", help="Bearer token for transfer")
    parser.add_argument("--dest-token", help="Bearer token for destination host")
    parser.add_argument("--src-token", help="Bearer token for source host")
    parser.add_argument("--push", dest="mode", action="store_const", const="push", help="Use push-mode for transfer (source manages transfer)", default="auto")
    parser.add_argument("--pull", dest="mode", action="store_const", const="pull", help="Use pull-mode for transfer (destination manages transfer)")
    parser.add_argument("--no-overwrite", dest="overwrite", action="store_false", default=True, help="Disable overwrite of existing files.")
    parser.add_argument("--streams", dest="streams", help="Allow multiple streams", default=1,
                        type=int)
    parser.add_argument("src")
    parser.add_argument("dest")

    args = parser.parse_args()

    if args.streams < 1:
        # sys.stderr.write behaves the same on Python 2 and 3, unlike the
        # Python-2-only ``print >> sys.stderr`` statement.
        sys.stderr.write("Invalid number of streams specified: %d\n" % args.streams)
        sys.exit(1)

    # A shared token is only needed when at least one side lacks its own.
    if not args.token and (not args.src_token or not args.dest_token):
        args.token = _default_token_path()
        if args.token is None:
            sys.stderr.write("No token file found in user environment.\n")
            sys.exit(1)

    if not args.src_token:
        args.src_token = args.token
    if not args.dest_token:
        args.dest_token = args.token

    return args

def get_token(fname):
    """Read a bearer token from a token file.

    The token is the first line that, after stripping surrounding
    whitespace, is neither empty nor a ``#`` comment.

    Args:
        fname: Path of the token file to read.

    Returns:
        The token string with surrounding whitespace removed.

    Raises:
        Exception: If the file contains no token line.
        IOError/OSError: If the file cannot be opened.
    """
    with open(fname, "r") as fp:
        for line in fp:
            candidate = line.strip()
            # Skip blank lines and comments so a leading empty line is not
            # mistaken for an (empty) token.
            if not candidate or candidate.startswith("#"):
                continue
            return candidate
    raise Exception("No token found in specified file (%s)" % fname)

def determine_mode(args):
    """Pick the transfer mode by probing the destination with OPTIONS.

    Args:
        args: Parsed arguments; only ``args.dest`` (the destination URL) is
            used.

    Returns:
        ``'pull'`` if the destination's ``Allow`` response header lists the
        ``COPY`` method, otherwise ``'push'``.
    """
    # NOTE(review): this probe is sent without the bearer token or the custom
    # CA path that main() uses for the COPY request - confirm that is intended.
    verbs = requests.options(args.dest)
    allow = verbs.headers.get('allow', '')
    # The header is a comma-separated list that normally has a space after
    # each comma ("GET, HEAD, COPY"), so every element must be stripped
    # before comparing.
    allowed_methods = [method.strip() for method in allow.split(",")]
    if 'COPY' in allowed_methods:
        return 'pull'
    return 'push'

def main():
    """Issue one HTTP COPY request for the requested transfer and print the result.

    Steps:
      1. Parse arguments and read the source and destination tokens from
         their files.
      2. Build the request headers: ``Overwrite`` (``T``/``F``) and, when
         more than one stream is requested, ``X-Number-Of-Streams``.
      3. Resolve ``auto`` mode via ``determine_mode``.
      4. In pull mode the COPY goes to the destination URL with a ``Source``
         header; in push mode it goes to the source URL with a
         ``Destination`` header.  The contacted side's token is sent in
         ``Authorization`` and the other side's token is passed along in
         ``Copy-Header``.
      5. Print the response status code, headers and body.
    """
    args = parse_args()

    src_token = get_token(args.src_token)
    dest_token = get_token(args.dest_token)

    headers = {'Overwrite': 'T' if args.overwrite else 'F'}
    if args.streams > 1:
        headers['X-Number-Of-Streams'] = str(args.streams)

    mode = args.mode
    if mode == "auto":
        mode = determine_mode(args)
        # Single-argument print(...) produces the same output on Python 2
        # and Python 3.
        print("Auto detect determined %s mode" % mode)

    if mode == 'pull':
        headers['Authorization'] = 'Bearer %s' % dest_token
        headers['Source'] = args.src
        headers['Copy-Header'] = 'Authorization: Bearer %s' % src_token
        url = args.dest
    else:  # Push
        headers['Authorization'] = 'Bearer %s' % src_token
        headers['Destination'] = args.dest
        headers['Copy-Header'] = 'Authorization: Bearer %s' % dest_token
        url = args.src

    with requests.Session() as session:
        # Verify the server certificate against this CA location.
        session.verify = '/etc/grid-security/certificates'
        # The request is made exactly once; the previous retry-style loop
        # always broke out after its first iteration.
        resp = session.request('COPY', url, headers=headers, allow_redirects=True)
        print(resp.status_code)
        print(resp.headers)
        print(resp.text)

# Run the transfer only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
