#!/usr/bin/python

# Install an arbitrary set of RPMs on an Atomic Host.
# Typical usage is e.g. kernel-debug.

from __future__ import print_function

import os
import sys
import subprocess
import hashlib
import json
import stat
import argparse
import tempfile
import shutil
from gi.repository import GLib, Gio, OSTree

def fatal(msg):
    """Report msg on standard error and terminate with exit status 1.

    A trailing newline is appended to msg.  This function never returns:
    it always raises SystemExit.
    """
    line = msg + '\n'
    sys.stderr.write(line)
    raise SystemExit(1)

def print_running(args):
    """Print a "Running: ..." line describing the command about to be run.

    args is either a ready-made command string, which is printed verbatim,
    or a sequence of arguments, which is rendered with
    subprocess.list2cmdline().
    """
    # type('') / type(u'') resolve to (str, unicode) on Python 2 and to
    # (str, str) on Python 3, so no Python 2-only name is referenced.
    string_types = (type(''), type(u''))
    if isinstance(args, string_types):
        argstr = args
    else:
        argstr = subprocess.list2cmdline(args)
    print("Running: {0}".format(argstr))
    # Flush so this line is emitted before any output the child process
    # writes to the same (possibly block-buffered) stdout.
    sys.stdout.flush()

def run_sync(args, **kwargs):
    """Log the command line, then run the command and require success.

    args and kwargs are handed unchanged to subprocess.check_call(), so a
    non-zero exit status surfaces as subprocess.CalledProcessError.
    Nothing is returned.
    """
    print_running(args)
    subprocess.check_call(args, **kwargs)

def run_sync_get_rc(args, **kwargs):
    """Log the command line, run the command, and return its exit status.

    args and kwargs are handed unchanged to subprocess.call(); a non-zero
    exit status is returned to the caller rather than raised.
    """
    print_running(args)
    exit_status = subprocess.call(args, **kwargs)
    return exit_status

# Command line: one or more RPM packages, plus a debugging switch that
# targets the first deployment instead of the booted one.
parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]))
parser.add_argument(
    'pkgs',
    metavar='PKG',
    type=str,
    nargs='+',
    help='An RPM package',
)
parser.add_argument(
    '--first-deployment',
    action='store_true',
    help='Operate on the first deployment, not the booted one (for debugging)',
)
args = parser.parse_args()

# Printed after a successful deploy to explain how to activate or revert it.
UNDO_MESSAGE = (
    "\n"
    "You will need to reboot for the update to take effect.  To\n"
    "undo, use `atomic host rollback`, and reboot again.\n"
)

def stbuf_equal(a, b):
    """Return True when two stat results have the same device and inode.

    Only the st_dev and st_ino attributes of a and b are compared.
    """
    identity_a = (a.st_dev, a.st_ino)
    identity_b = (b.st_dev, b.st_ino)
    return identity_a == identity_b

def update_checksum_from_stream(csum, f):
    """Feed everything remaining in stream f into the hash object csum.

    f is read in 8192-byte chunks via f.read() and each chunk is passed to
    csum.update().  Reading stops at the first empty chunk.  Nothing is
    returned; csum is updated in place.
    """
    chunk = f.read(8192)
    # Test truthiness rather than comparing against '': an exhausted binary
    # stream yields b'', which never equals '' on Python 3 and would
    # otherwise loop forever.
    while chunk:
        csum.update(chunk)
        chunk = f.read(8192)

def find_kver(rootdir):
    """Return the name of a subdirectory of <rootdir>/usr/lib/modules.

    Entries are examined in os.listdir() order and the first one that is a
    directory is returned; non-directories are skipped.  If no directory
    is found, fatal() is called.
    """
    modules_root = rootdir + '/usr/lib/modules'
    for entry in os.listdir(modules_root):
        if os.path.isdir(os.path.join(modules_root, entry)):
            return entry
    fatal("Failed to determine kernel version from /usr/lib/modules")

def get_kernel_path(rootdir):
    """Return the path of an entry in <rootdir>/boot named 'vmlinuz*'.

    Entries are examined in os.listdir() order and the first match wins.
    Returns None when no entry name starts with 'vmlinuz'.
    """
    bootdir = rootdir + '/boot'
    for entry in os.listdir(bootdir):
        if not entry.startswith('vmlinuz'):
            continue
        return bootdir + '/' + entry
    return None

# Open the OSTree sysroot found at /host and get a handle on its repository.
# (Presumably /host is the host's root filesystem mounted into the
# environment this script runs in -- confirm against the deployment docs.)
sysroot_path = '/host'
sysroot = OSTree.Sysroot.new(Gio.File.new_for_path(sysroot_path))
sysroot.load(None)
_,repo = sysroot.get_repo(None)

# Choose the deployment to modify: either the first one listed, or the one
# whose deployment directory is the very same directory (same device and
# inode) as the root of sysroot_path.
host_root_stbuf = os.stat(sysroot_path)
target_deployment = None
if args.first_deployment:
    target_deployment = sysroot.get_deployments()[0]
else:
    for deployment in sysroot.get_deployments():
        dpath = sysroot_path + '/' + sysroot.get_deployment_dirpath(deployment)
        stbuf = os.stat(dpath)
        if stbuf_equal(host_root_stbuf, stbuf):
            target_deployment = deployment
            break
if target_deployment is None:
    fatal("error: Failed to detect booted deployment")

target_osname = target_deployment.get_osname()
target_commit = target_deployment.get_csum()
print("Target deployment: {0} {1}".format(target_osname, target_commit))

# Paths inside the target deployment, and the treefile shipped with it
# (consulted later for extra initramfs arguments).
target_dpath = sysroot_path + '/' + sysroot.get_deployment_dirpath(target_deployment)
target_rpmdb_path = target_dpath + '/usr/share/rpm'
treefile_path = target_dpath + '/usr/share/rpm-ostree/treefile.json'
with open(treefile_path) as f:
    treefile = json.load(f)

# https://bugzilla.redhat.com/show_bug.cgi?id=1293718#c14
# If the altfiles NSS module is missing locally, copy it and the matching
# nsswitch.conf from the host.
altfiles_path = '/usr/lib64/libnss_altfiles.so.2'
if not os.path.exists(altfiles_path):
    print("Copying host {}".format(altfiles_path))
    shutil.copy(sysroot_path + altfiles_path, altfiles_path)
    nsswitch_conf = '/etc/nsswitch.conf'
    shutil.copy(sysroot_path + nsswitch_conf, nsswitch_conf)

# Note that we use the host's /var/tmp because
# we want hardlinks from /host/ostree/repo to work.
# Except we actually traverse the real path instead of /host/var because
# that's a bind mount, and wouldn't let us make hard links.
workdir = tempfile.mkdtemp(dir='/host/ostree/deploy/{0}/var/tmp'.format(target_osname))
try:
    # Check out the target commit into a scratch root we can modify.
    checkout_tmp = workdir + '/checkout'
    options = OSTree.RepoCheckoutOptions()
    print("Preparing root")
    repo.checkout_tree_at(options, -1, checkout_tmp, target_commit, None)

    # Now, break hardlinks for the rpmdb, since RPM does writes in place
    print("Making rpmdb read-write")
    rpmdb_path = checkout_tmp + '/usr/share/rpm'
    rpmdb_path_tmp = rpmdb_path + '.tmp'
    os.rename(rpmdb_path, rpmdb_path_tmp)
    subprocess.check_call(['cp', '-a', rpmdb_path_tmp, rpmdb_path])
    shutil.rmtree(rpmdb_path_tmp)

    # Temporarily make /etc writable again
    os.rename(checkout_tmp + '/usr/etc', checkout_tmp + '/etc')

    # Any package whose file name starts with "kernel-" switches on the
    # kernel replacement path below.
    includes_kernel = False
    for pkg in args.pkgs:
        if os.path.basename(pkg).startswith('kernel-'):
            includes_kernel = True
            print("Detected kernel- package in set")
            break

    if includes_kernel:
        print("Removing non-debug kernel in temporary root")
        rpm_opts = ['rpm', '--quiet', '--dbpath', '/usr/share/rpm', '--root', checkout_tmp,
                    '--noscripts', '--notriggers', '--nodeps']
        run_sync(rpm_opts + ['--erase', 'kernel'])

        # Now prepare the boot locations
        checkout_tmp_boot = checkout_tmp + '/boot'
        shutil.rmtree(checkout_tmp_boot)
        # Blow away leftover module files
        checkout_tmp_modules = checkout_tmp + '/lib/modules'
        for child in os.listdir(checkout_tmp_modules):
            shutil.rmtree(checkout_tmp_modules + '/'+ child)
        os.mkdir(checkout_tmp_boot)
        checkout_tmp_usrlib_boot = checkout_tmp + '/usr/lib/ostree-boot'
        if os.path.isdir(checkout_tmp_usrlib_boot):
            shutil.rmtree(checkout_tmp_usrlib_boot)

    # Install/upgrade the requested packages into the scratch root.
    child_argv = ['rpm', '--dbpath', '/usr/share/rpm', '--root', checkout_tmp, '--replacepkgs', '-U', ]
    run_sync(child_argv + args.pkgs)

    if includes_kernel:
        print("Preparing kernel and initramfs")
        kver = find_kver(checkout_tmp)
        run_sync(['chroot', checkout_tmp, 'depmod', kver])
        # /proc is mounted inside the scratch root for the dracut run and
        # always unmounted again afterwards.
        run_sync(['mount', '-t', 'proc', 'proc', checkout_tmp + '/proc'])
        try:
            dracut_argv = ['chroot', checkout_tmp, 'dracut', '-v', '--tmpdir=/boot',
                           '-f', '/boot/initramfs.img', '--kver', kver]
            dracut_argv.extend(treefile.get('initramfs-args', []))
            run_sync(dracut_argv)
        finally:
            run_sync(['umount', checkout_tmp + '/proc'])

    # Make /etc back into ro defaults (note has to be after dracut, as
    # it wants to pick up files from /etc)
    os.rename(checkout_tmp + '/etc', checkout_tmp + '/usr/etc')

    if includes_kernel:
        # Hash the kernel + initramfs together to produce the "bootcsum"
        bootcsum = hashlib.sha256()
        initramfs_tmppath = checkout_tmp + '/boot/initramfs.img'
        # Both files are binary, so read them as bytes for hashing.
        with open(initramfs_tmppath, 'rb') as f:
            update_checksum_from_stream(bootcsum, f)
        kpath = get_kernel_path(checkout_tmp)
        with open(kpath, 'rb') as f:
            update_checksum_from_stream(bootcsum, f)
        os.rename(kpath, checkout_tmp + '/boot/vmlinuz-{0}'.format(bootcsum.hexdigest()))
        os.rename(initramfs_tmppath, checkout_tmp + '/boot/initramfs-{0}'.format(bootcsum.hexdigest()))

        # Reimplementing RPM_OSTREE_POSTPROCESS_BOOT_LOCATION
        os.rename(checkout_tmp + '/boot', checkout_tmp + '/usr/lib/ostree-boot')
        os.mkdir(checkout_tmp + '/boot')

    print("Committing to local OSTree repository...")
    # The origin file records the base commit and the layered package names;
    # it is passed to `ostree admin deploy` below.
    fd,origin_path = tempfile.mkstemp(dir='/host/tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            f.write("""[origin]
from-commit={0}
packages={1}
""".format(target_commit, " ".join(map(os.path.basename, args.pkgs))))
        temp_local = 'temp-local'
        try:
            run_sync(['ostree', '--repo=/host/ostree/repo', 'commit', '--link-checkout-speedup',
                      '--add-metadata-string=version=local', '-s', 'Tree with local packages',
                      '-m', 'Generated from {0}'.format(target_commit), '-b', temp_local],
                     cwd=checkout_tmp)
            print("Deploying new commit...")
            # The deploy runs in PID 1's mount namespace, where the file is
            # visible without the leading /host; strip only that first
            # occurrence.
            run_sync(['nsenter', '-t', '1', '--mount', '--', 'ostree', 'admin', 'deploy',
                      '--os=' + target_osname,
                      '--origin-file=' + origin_path.replace('/host', '', 1), temp_local])
        finally:
            # The temporary ref is removed whether or not the deploy worked.
            rc = run_sync_get_rc(['nsenter', '-t', '1', '--mount', '--', 'ostree', 'refs', '--delete', temp_local])
            if rc != 0:
                print("WARNING: failed to delete temp-local ref")
    finally:
        # Do not leave the origin file behind in the host's /tmp.
        try:
            os.unlink(origin_path)
        except OSError:
            print("WARNING: failed to remove {0}".format(origin_path))
    print("")
    print("Success!")
    print(UNDO_MESSAGE)
finally:
    # Set PRESERVE_TEMP in the environment to keep the scratch checkout
    # around for inspection.
    if not ('PRESERVE_TEMP' in os.environ):
        shutil.rmtree(workdir)
    else:
        print("Saved temporary state in {0}".format(workdir))
