#!/usr/libexec/platform-python


#author Justin Sherrill jsherril@redhat.com
#version: 2.0-beta
#
#



import getopt
import sys
import getpass
if sys.version_info[0] == 3:
    from xmlrpc.client import Server
    from xmlrpc.client import Fault
else:
    from xmlrpclib import Server
    from xmlrpclib import Fault
import os
from os import path
from optparse import OptionParser
from rhn.connections import idn_ascii_to_puny
from socket import gethostname
from rhn.i18n import ustr

# Where gather() writes its listings, where the shipped data files live,
# and the root of the release tree that gather() walks.
DATA_DIR = "./data/"
DATA_DIR_RHN = "/usr/share/rhn/channel-data/"
GATHER_DIR = "/mnt/engarchive2/released/"

# Value sets used both for the option help text and by gather().
versions = ("7", "6", "5", "4", "3", "2.1")
releases = (
    "AS",
    "ES",
    "WS",
    "Server",
    "Client",
    "Desktop",
    "Appliance",
    "ComputeNode",
    "WebServer",
    "Workstation",
)
arches = ("i386", "ia64", "ppc", "s390", "s390x", "x86_64", "ppc64")
# "GOLD" followed by U1 .. U11
updates = ("GOLD",) + tuple("U%d" % level for level in range(1, 12))
#RHEL 5 yum repos on the discs
subrepos = (
    "VT",
    "Cluster",
    "ClusterStorage",
    "Workstation",
    "ResilientStorage",
    "ScalableFileSystem",
    "LoadBalancer",
    "HighAvailability",
)
#things shipped on other cds
extras = ("Extras", "Supplementary")


# Run-time settings; main() fills these in from the command line.
bundleSize = None
server = None
user = None
password = None
filename = None
ssl = True
parser = None
clone = False

client = None

def get_args():
    """Build the option parser and parse the process command line.

    The parser is stored in the module-level ``parser`` variable.
    Returns the ``(options, args)`` pair produced by ``parse_args()``.
    """
    global parser

    def listed(values):
        # Render one of the module-level value tuples for a help string.
        return ', '.join(values)

    usage_text = """spacewalk-create-channel  -l USER  -s SERVER [-p PASSWORD]  -v VERSION -r RELEASE
                -u UPDATE -a ARCH -e EXTRA  [ -D DATAFILE]  [-c SRCCHAN] -d DESTCHAN [-N DESTNAME]  [-P PARENT]
                [-C | --clear] [-n | --nossl]"""
    parser = OptionParser(usage=usage_text)
    add = parser.add_option

    # connection options
    add("-n", "--nossl", action="store_false", dest="ssl", default=True,
        help="disables the use of SSL when connecting to the server")
    add("-C", "--clear", action="store_true", dest="clear_channel",
        help="remove all channel packages from the channel before adding")
    add("-s", "--server", action="store", dest="server", default=gethostname(),
        help="the server to connect to (i.e. hostname.domain.com) [default: %default]")
    add("-l", "--user", action="store", dest="user",
        help="the user to connect as")
    add("-p", "--password", action="store", dest="password",
        help="the password to use for the connection (if not specified, will be prompted for)")

    # what to put in the channel
    add("-v", "--version", action="store", dest="version",
        help='the version to use (i.e.' + listed(versions) + ')')
    add("-r", "--release", action="store", dest="release",
        help='the release to use (i.e. ' + listed(releases) + ')')
    add("-a", "--arch", action="store", dest="arch",
        help='the arch to use (i.e. ' + listed(arches) + ')')
    add("-u", "--update", action="store", dest="update_level",
        help='the update level to use (i.e. ' + listed(updates) + ')')
    add("-e", "--extra", action="store", dest="extra",
        help='the child channel/repo to use (i.e. ' + listed(extras) + ')')
    add("-D", "--data", action="store", dest="data",
        help='a data file to use, may use this instead of version, release, update, arch,  and extra')

    # source and destination channels
    add("-c", "--sourceChannel", action="store", dest="source_channel",
        help=' the channel to pull packages from (will be auto detected if not provided)')
    add("-d", "--destChannel", action="store", dest="dest_channel",
        help='the label of the destintation channel to push the packages to (will be created if not existing)')
    add("-P", "--parentChannel", action="store", dest="parent_channel", default="",
        help='if specified, and --destChannel does not exist, --destChannel will be created with this parent')
    add("-L", "--clone", action="store_true", dest="clone", default=False,
        help='if creating destChannel, clone it from the source channel, before adding packages (Satellite 5.1 or newer required)')
    add("-K", "--skiplist", action="store", dest="skiplist",
        help="a filename with a list of package names to ignore when adding to the destination channel")

    # miscellaneous
    add("-g", "--gather", action="store_true", dest="gather", default=False,
        help='used to gather data files, generally not useful to most people')
    add("-b", "--bundle", action="store", dest="bundle", type='int', default='50',
        help='if you are getting "502 proxy" errors, try using a smaller value (default 50)')
    add("-N", "--name", action="store", dest="dest_name",
        help='if specified, set the descriptive name of the destination channel, else use channel label')
    add("-o", "--currentState", action="store_true", default=False, dest="currentState",
        help='option for cloning a channel to the current state')

    return parser.parse_args()


def validate_options(options):
    """Check that the mandatory options were supplied.

    --user and --destChannel are always required; --version, --release,
    --arch and --update are required only when no --data file is given.
    One message is printed per missing option and the process then exits
    with status 1.  Returns None when everything needed is present.
    """
    required = [(options.user, "--user must be specified")]

    if not options.data:
        required.extend([
            (options.version, "--version must be specified"),
            (options.release, "--release must be specified"),
            (options.arch, "--arch must be specified"),
            (options.update_level, "--update must be specified"),
        ])

    required.append((options.dest_channel, "--destChannel must be specified"))

    complaints = [message for value, message in required if not value]
    for message in complaints:
        print(message)

    if complaints:
        sys.exit(1)


def _clone_supported(api_version):
    """Return False only when ``api_version`` reads as a number below 5.1.

    The value is converted with float() before comparing, because ordering
    a string against a float raises TypeError on Python 3.  A value that
    cannot be converted is treated as supported.
    """
    try:
        return float(api_version) >= 5.1
    except (TypeError, ValueError):
        return True


def main():
    """Parse the command line, log in to the server and populate the channel.

    Stores the connection settings in the module-level globals that
    populate() and clearChannel() read (bundleSize, clone, currentState, ...).
    Exits with status 0 after a --gather run and with status 1 when the
    options are incomplete, the bundle size is not positive, the data file
    does not exist or the login is rejected.
    """
    global server, user, password, filename, ssl, bundleSize, clone, currentState

    (options, args) = get_args()

    if options.gather:
        gather(False)
        sys.exit(0)

    validate_options(options)

    server = idn_ascii_to_puny(options.server)
    user = options.user
    password = options.password
    if password is None:
        password = getpass.getpass()
    #even though this is supposed to be an int already, in RHEL4's python it isn't
    bundleSize = int(options.bundle)
    if bundleSize <= 0:
        # a zero or negative slice size would never drain the id lists
        print("--bundle must be greater than zero")
        sys.exit(1)

    #did you provide a data file?  If not we need all the information
    # (same truthiness test as validate_options, so an empty -D is handled alike)
    if not options.data:
        version = options.version
        release = options.release
        update = options.update_level
        if update == "0":
            update = "gold"
        arch = options.arch
        extra = options.extra
        filename = getFileName(version, update, release, arch, extra)
        fullpath = DATA_DIR_RHN + filename
    else:
        filename = os.path.basename(options.data)
        fullpath = options.data
        version, release, update, arch, extra = parseDataFilename(filename)

    ssl = options.ssl
    clear = options.clear_channel

    if not os.path.exists(fullpath):
        print("Error: data file %s does not exist" % fullpath)
        sys.exit(1)

    parent = options.parent_channel
    clone = options.clone
    currentState = options.currentState

    newChannelLabel = options.dest_channel.lower()
    srcChanLabel = options.source_channel

    if not options.dest_name:
        newChannelName = options.dest_channel
    else:
        newChannelName = options.dest_name

    if srcChanLabel is None:
        print("You have not specified a source channel, we will try to determine it from inputs")
        srcChanLabel = findSrcChan(version, release, arch, extra)
    srcChanLabel = srcChanLabel.lower()

    print("Trying with source channel: %s" % srcChanLabel)

    # now lets login to the server
    proto = "http"
    if ssl:
        proto = proto + "s"
    client = Server("%s://%s/rpc/api" % (proto, ustr(server)))
    # NOTE(review): this only warns and then carries on, as before; confirm
    # whether the run should stop here instead.
    if clone and not _clone_supported(client.api.getVersion()):
        print("--clone cannot be used with a Satellite version older than 5.1")
    try:
        auth = client.auth.login(user, password)
    except Fault:
        excInfo = sys.exc_info()[1]
        print("\n %s " % excInfo.faultString)
        sys.exit(1)

    skiplist = []
    if options.skiplist:
        skiplist = read_skip_list(options.skiplist)

    # hand over the path that was checked above, so a -D file that lives
    # outside DATA_DIR_RHN is the one that actually gets read
    populate(client, auth, srcChanLabel, newChannelLabel, newChannelName, filename, parent, clear,
             skiplist=skiplist, datafile=fullpath)


def populate(client, auth, srcChannel, newChannel, newChannelName, filename, parent, clear, skiplist=None, datafile=None):
    """Create or reuse the destination channel and push the listed packages.

    client         -- XML-RPC server proxy
    auth           -- session key returned by the login call
    srcChannel     -- label of the channel the packages are taken from
    newChannel     -- label of the destination channel
    newChannelName -- descriptive name used when the destination is created
    filename       -- data file name, looked up under DATA_DIR_RHN when
                      ``datafile`` is not given
    parent         -- parent channel label for a newly created destination
                      (empty for none)
    clear          -- when true, clearChannel() is called before pushing
    skiplist       -- package names that must not be pushed
    datafile       -- full path of the data file to read; defaults to
                      DATA_DIR_RHN + filename

    Exits with status 1 when the source channel is not in the channel list.
    """
    if skiplist is None:
        skiplist = []
    if datafile is None:
        datafile = DATA_DIR_RHN + filename

    chanList = client.channel.listSoftwareChannels(auth)

    src_label = None
    arch_label = None
    new_label = None
    for chan in chanList:
        label = getChannelAttr(chan, 'label')
        if label == srcChannel:
            src_label = label
            arch_label = getChannelAttr(chan, 'arch')
        elif label == newChannel:
            new_label = label

    if src_label is None:
        print("Error: Source Channel '%s' could not be found." % srcChannel)
        sys.exit(1)
    if new_label is None:
        if clone:
            print("Cloning channel %s to %s" % (srcChannel, newChannel))
            details_map = { 'name': newChannelName, 'label': newChannel, 'summary': newChannel }
            if parent:
                details_map['parent_label'] = parent
            #we must invert currentState, because API uses "True" for default state
            client.channel.software.clone(auth, srcChannel, details_map, not currentState)
        else:
            print("Creating channel, %s, with arch %s" % (newChannel, arch_label))
            client.channel.software.create(auth, newChannel, newChannelName, newChannel, getArchLabel(arch_label, client, auth), parent)
        if parent:
            print("with parent %s" % parent)
    else:
        print("Reusing %s as destination channel" % newChannel)
    existing_packs = client.channel.software.listAllPackages(auth, srcChannel)

    if clear:
        # NOTE(review): the ids handed to clearChannel are those of the
        # *source* channel's packages - confirm that is what is intended.
        clearChannel(newChannel, existing_packs, client, auth)

    # strip new lines; blank lines are dropped because splitFilename()
    # cannot handle a name without a '.'
    with open(datafile) as handle:
        fileList = sorted(line.strip() for line in handle if line.strip())
    pack_num = 0
    skip_num = 0
    ids_to_add = []

    print("%d packages in source file to push." % len(fileList))

    # index the source packages by their name-version-release-arch key once,
    # keeping the server's order within each key, instead of rescanning the
    # whole package list for every file name
    packs_by_nvra = {}
    for src_pack in existing_packs:
        packs_by_nvra.setdefault(splitPackage(src_pack), []).append(src_pack)

    #for each filename, find the matching source package id
    for rpm in fileList:
        for src_pack in packs_by_nvra.get(splitFilename(rpm), ()):
            key = 'name'
            if key not in src_pack:
                key = 'package_name'
            if src_pack[key] in skiplist:
                skip_num = skip_num + 1
                continue
            pack_num = pack_num + 1
            ids_to_add.append(getPackageAttr(src_pack, 'id'))
            break
    print("Pushing %s packages, please wait." % str(pack_num))
    if skiplist:
        print("Skipping %s packages based off of skip list" % str(skip_num))

    # push in slices of bundleSize ids, rewriting the progress text in place
    while ids_to_add:
        sys.stdout.write('%d of %d' % (pack_num-len(ids_to_add), pack_num))
        sys.stdout.flush()
        client.channel.software.addPackages(auth, newChannel, ids_to_add[:bundleSize])
        del ids_to_add[:bundleSize]
        sys.stdout.write('%s\r' % (' ' * 20))
        sys.stdout.flush()

    print("Successfully pushed %s packages out of %d" % (str(pack_num), len(fileList)))


def read_skip_list(filename):
    """Read the package names to skip, one per line, from ``filename``.

    Each line is stripped of surrounding whitespace.  The resulting list is
    printed and returned.  Exits with status 1 when the file does not exist.
    """
    if not os.path.exists(filename):
        print("Specified skip list %s does not exist!" % filename)
        sys.exit(1)
    # the context manager closes the file even if reading fails
    with open(filename) as handle:
        nameList = [line.strip() for line in handle]
    print(nameList)
    return nameList


def clearChannel(label, packages, client, auth):
    """Remove the given packages from channel ``label`` in bundles.

    label    -- label of the channel to remove packages from
    packages -- package structures whose ids are removed
    client   -- XML-RPC server proxy
    auth     -- session key

    The ids are sent in slices of the module-level ``bundleSize``.
    """
    print("Clearing channel packages")
    ids_to_remove = [getPackageAttr(pack, 'id') for pack in packages]
    num_to_remove = len(ids_to_remove)
    while ids_to_remove:
        # Write without a newline and finish with '\r' so the counter is
        # rewritten in place, the same way populate() reports its progress.
        sys.stdout.write('%d of %d' % (num_to_remove-len(ids_to_remove), num_to_remove))
        sys.stdout.flush()
        client.channel.software.removePackages(auth, label, ids_to_remove[:bundleSize])
        del ids_to_remove[:bundleSize]
        # parenthesised: '%' binds tighter than '*', so the unparenthesised
        # form repeated " \r" twenty times instead of padding with spaces
        sys.stdout.write('%s\r' % (' ' * 20))
        sys.stdout.flush()
    print("Finished clearing channel")



def getArchLabel(archName, client, auth):
    """Return the label of the channel arch whose name equals ``archName``.

    The arches are fetched with ``client.channel.software.listArches``.
    Returns None when no arch has that name.
    """
    for candidate in client.channel.software.listArches(auth):
        if getArchAttr(candidate, 'name') != archName:
            continue
        return getArchAttr(candidate, 'label')
    return None


def splitPackage(map):
    """Return the 'name-version-release-arch_label' key for a package structure."""
    fields = ('name', 'version', 'release', 'arch_label')
    return '-'.join([getPackageAttr(map, field) for field in fields])

def splitFilename(filename):
    """Turn an rpm file name into the key format built by splitPackage().

    The last dot-separated piece is dropped, and the piece before it is
    attached to the remainder with '-' instead of '.'.
    Raises IndexError when the name contains no '.'.
    """
    pieces = filename.split('.')
    stem = '.'.join(pieces[:-2])
    return '%s-%s' % (stem, pieces[-2])



def getAttr(map, base, attribute):
    """Return ``map[attribute]``, falling back to ``map[base + attribute]``.

    map       -- dictionary to look in
    base      -- prefix tried when the plain key is missing or None
    attribute -- key to look up

    Returns None when neither key is present.
    """
    label = map.get(attribute)
    if label is None:
        label = map.get(base + attribute)
    return label

def getPackageAttr(map, attribute):
    """Look up ``attribute`` in a package structure, also trying the 'package_' prefix."""
    prefix = 'package_'
    return getAttr(map, prefix, attribute)

def getChannelAttr(map, attribute):
    """Look up ``attribute`` in a channel structure, also trying the 'channel_' prefix."""
    prefix = 'channel_'
    return getAttr(map, prefix, attribute)

def getArchAttr(map, attribute):
    """Look up ``attribute`` in an arch structure, also trying the 'arch_' prefix."""
    prefix = 'arch_'
    return getAttr(map, prefix, attribute)

#returns tuple (version, release, update, arch, extra)
def parseDataFilename(file):
    """Split a lower-cased 'version-update-release-arch[-extra]' file name.

    ``extra`` is None when the name has exactly four dash-separated parts.
    Raises IndexError when there are fewer than four parts.
    """
    parts = file.lower().split('-')
    extra = None if len(parts) == 4 else parts[4]
    # the file name carries the update before the release; the tuple swaps them
    return (parts[0], parts[2], parts[1], parts[3], extra)

def findSrcChan(version, release, arch, extra = None):
    """Work out the source channel label for the given product.

    version -- product version, e.g. "6" or "2.1"
    release -- release name in any letter case, e.g. "Server"
    arch    -- architecture, e.g. "x86_64"
    extra   -- optional sub-repository/extras name, in any letter case

    Returns the channel label as a string.
    """
    low_release = release.lower()

    if version == "2.1":
        # 2.1 channels use fixed labels; match on the lower-cased release so
        # that "ES"/"WS" are recognised the same way "AS" is
        legacy_labels = {
            "as": "redhat-advanced-server-2.1",
            "es": "redhat-ent-linux-i386-es-2.1",
            "ws": "redhat-ent-linux-i386-ws-2.1",
        }
        if low_release in legacy_labels:
            return legacy_labels[low_release]
    if low_release == "computenode":
        low_release = "hpc-node"
    #if we don't have a subrepo/extras we're done
    if extra is None:
        return "rhel-%s-%s-%s" % (arch, low_release, version)

    #else we do, so lets process that
    extra_map = {"clusterstorage": "cluster-storage", "resilientstorage": "rs",
                 "scalablefilesystem": "sfs", "highavailability": "ha", "loadbalancer": "lb"}
    low_extra = extra.lower()
    low_extra = extra_map.get(low_extra, low_extra)
    if low_extra == "extras":
        # "extras" goes after the version, every other extra goes before it
        return "rhel-%s-%s-%s-%s" % (arch, low_release, version, low_extra)
    return "rhel-%s-%s-%s-%s" % (arch, low_release, low_extra, version)

#this function has no real use outside of red hat
def gather(clear):
    """Save an RPM listing for every release tree found under GATHER_DIR.

    Walks every version/update/release/arch combination from the module
    tuples and calls saveRpmList() for the main tree, the sub-repositories
    and the extras of each one, writing the listings into DATA_DIR.

    clear -- when true, DATA_DIR is removed before gathering starts
    """
    if clear:
        os.system("rm -rf " + DATA_DIR)

    if not path.isdir(DATA_DIR):
        os.mkdir(DATA_DIR)

    for version in versions:
        for update in updates:
            for release in releases:
                for arch in arches:
                    if float(version) <= 4:
                        repoDir = "RHEL-%s/%s/%s/%s/tree/" % (version, update, release, arch)
                        fullDir = GATHER_DIR + repoDir
                        saveRpmList(fullDir +  "RedHat/RPMS/", DATA_DIR + getFileName(version, update,release, arch))
                        #save extras
                        suppl = extras[0]
                        repoDir = "RHEL-%s-%s/%s/%s/%s/tree/" % (version, suppl, update, release, arch)
                        fullDir = GATHER_DIR + repoDir
                        saveRpmList(fullDir +  "RedHat/RPMS/", DATA_DIR + getFileName(version, update,release, arch, suppl))
                    elif float(version) == 5:
                        repoDir = "RHEL-%s-%s/%s/%s/os/" % (version, release, update, arch)
                        fullDir = GATHER_DIR + repoDir
                        saveRpmList(fullDir +  release + "/", DATA_DIR + getFileName(version, update,release, arch))
                        #now lets handle the other repos:
                        for subrepo in subrepos:
                            saveRpmList(fullDir +  subrepo + "/", DATA_DIR + getFileName(version, update,release, arch, subrepo))
                        #do supplementary
                        suppl = extras[1]
                        extraDir = "RHEL-%s-%s-%s/%s/%s/os/%s/" % (version, release, suppl, update, arch, suppl)
                        fullDir = GATHER_DIR + extraDir
                        saveRpmList(fullDir + "/", DATA_DIR + getFileName(version, update,release, arch, suppl))
                    elif float(version) == 6:
                        if update == "GOLD":  #RHEL 6 uses X.Y notation instead of X UY
                            update_num = "0"
                        else:
                            update_num = update.replace("U","") #strip the U
                        # RHEL-6/6.0/Server/x86_64/os/
                        repoDir = "RHEL-%s/%s.%s/%s/%s/os/" % (version, version, update_num, release, arch)
                        fullDir = GATHER_DIR + repoDir
                        if not saveRpmList(fullDir +  release + "/listing", DATA_DIR + getFileName(version, update,release, arch)):
                            # update_num is a string, so compare it as a number:
                            # ordering str against int raises TypeError on
                            # Python 3 and was always False on Python 2
                            if int(update_num) < 6:
                                sourceDir = fullDir +  release + "/Packages/"
                            else:
                                sourceDir = fullDir +  "/Packages/"
                            saveRpmList(sourceDir, DATA_DIR + getFileName(version, update,release, arch))
                        #now lets handle the other repos:
                        for subrepo in subrepos:
                            # prefer the "listing" file, fall back to the directory itself
                            if not saveRpmList(fullDir +  subrepo + "/listing", DATA_DIR + getFileName(version, update,release, arch, subrepo)):
                                saveRpmList(fullDir +  subrepo, DATA_DIR + getFileName(version, update,release, arch, subrepo))
                        #do supplementary
                        suppl = extras[1]
                        fullDir = GATHER_DIR + "RHEL-%s-%s/%s.%s/%s/%s/os/Packages/" % (version, suppl, version, update_num, release, arch)
                        saveRpmList(fullDir + "/", DATA_DIR + getFileName(version, update,release, arch, suppl))

                        #do optional
                        fullDir = GATHER_DIR + "RHEL-%s/%s.%s/%s/optional/%s/os/" % (version, version, update_num, release, arch)
                        saveRpmList(fullDir + "/Packages/", DATA_DIR + getFileName(version, update,release, arch, 'optional'))

                    elif float(version) == 7:
                        if update == "GOLD":
                            update_num = "0"
                        else:
                            update_num = update.replace("U","") #strip the U
                        # RHEL-7/7.0/Server/x86_64/os/
                        repoDir = "RHEL-%s/%s.%s/%s/%s/os/" % (version, version, update_num, release, arch)
                        fullDir = GATHER_DIR + repoDir

                        saveRpmList(fullDir + "/Packages/", DATA_DIR + getFileName(version, update,release, arch))

                        #now lets handle the other repos:
                        for subrepo in subrepos:
                            saveRpmList(fullDir +  '/addons/' + subrepo, DATA_DIR + getFileName(version, update,release, arch, subrepo))

                        #do supplementary
                        suppl = extras[1]
                        fullDir = GATHER_DIR + "Supp-RHEL-%s/%s.%s/%s/%s/os/Packages/" % (version, version, update_num, release, arch)
                        print(fullDir)
                        saveRpmList(fullDir + "/", DATA_DIR + getFileName(version, update,release, arch, suppl))

                        #do optional
                        fullDir = GATHER_DIR + "RHEL-%s/%s.%s/%s-optional/%s/os/" % (version, version, update_num, release, arch)
                        saveRpmList(fullDir + "/Packages/", DATA_DIR + getFileName(version, update,release, arch, 'optional'))


def saveRpmList(source, filename):
    """Write the RPM listing found at ``source`` to ``filename``.

    A regular file is copied as it is; for anything else the entries whose
    name contains "rpm" are listed, sorted and written out.
    Nothing is done, and False is returned, when ``source`` does not exist
    or ``filename`` already exists.  Returns True otherwise.
    """
    if not path.exists(source) or path.exists(filename):
        return False

    print(source)
    if path.isfile(source):
        command = "cp %s %s" % (source, filename)
    else: # isdir
        command = "ls %s | grep rpm | sort > %s " % (source, filename)
    os.system(command)
    return True

def getFileName(version, update, release, arch, additional=None):
    """Build the data file name 'version-update-release-arch[-Additional]'.

    ``update`` and ``release`` are lower-cased; ``additional``, when given,
    is appended in capitalized form.
    """
    base = "%s-%s-%s-%s" % (version, update.lower(), release.lower(), arch)
    if additional is None:
        return base
    return "%s-%s" % (base, additional.capitalize())


def help():
    """Print usage examples for the tool."""
    example_lines = [
        "",
        "Examples",
        "-------",
        "",
        "Create channel 'my-stable-channel' for RHEL 6 Server GOLD x86_64:",
        "spacewalk-create-channel  -l admin -s myserver.example.com -v 6 -u gold -r server -a x86_64 -d my-stable-channel",
        "",
        "or another way:",
        "spacewalk-create-channel  -l admin -s myserver.example.com  -D /usr/share/rhn/channel-data/6-gold-server-x86_64 -d my-stable-channel",
        "",
        "Upgrade prevously created channel 'my-stable-channel' to RHEL 6 Server u1 x86_64:",
        "spacewalk-create-channel  -l admin -s myserver.example.com -v 6 -u u1 -r server -a x86_64 -d my-stable-channel",
        "",
        "Add the Supplementary channel as a child channel of 'my-stable-channel':",
        "spacewalk-create-channel  -l admin -s myserver.example.com -v 6 -u u1 -r server -a x86_64 -e supplementary -P my-stable-channel",
        "",
    ]
    # the empty first and last entries give the leading and trailing newline
    print("\n".join(example_lines))


def error(msg):
    """Print ``msg`` and terminate the process with exit status -1."""
    print(msg)
    raise SystemExit(-1)


# Run the tool only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
