#!/usr/bin/python3
#
# Univention Configuration Registry
#
# SPDX-FileCopyrightText: 2004-2026 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""Decode LDAP base64 encoded "key:: value" pairs read from standard input to standard output."""


import binascii
import re
import sys
import traceback
from datetime import datetime

import heimdal
from samba.dcerpc import dnsp, drsblobs
from samba.ndr import ndr_print, ndr_unpack


# Usage:
# ldbsearch -H /var/lib/samba/private/sam.ldb cn=<dc> supplementalCredentials | s4search-decode

# Shared Krb5Context instance; stays None until the first decoder that needs
# Heimdal creates it, so input without key material never touches Heimdal.
krb5_context = None


class Krb5Context:
    """Bundle a Heimdal context with the encryption types it permits.

    Attributes:
        ctx: the object returned by ``heimdal.context()``.
        etypes: whatever ``ctx.get_permitted_enctypes()`` returned.
        etype_ids: ``toint()`` of every entry in ``etypes``, as a list.
    """

    def __init__(self):
        context = heimdal.context()
        permitted = context.get_permitted_enctypes()
        self.ctx = context
        self.etypes = permitted
        # Integer ids allow a plain membership test against decoded key types.
        self.etype_ids = [enctype.toint() for enctype in permitted]


# Display names for the numeric key types found in supplementalCredentials
# key entries; numbers missing from this table are printed as-is.
# NOTE(review): numbers presumably follow the Kerberos encryption type
# registry (1/3 = DES variants, 17/18 = AES) - confirm.
keytypes = {
    1: 'des_crc',
    3: 'des_md5',
    17: 'aes128',
    18: 'aes256',
}

# Matches a "name: value" or "name:: value" line: group 1 is the attribute
# name, group 2 is the rest of the line after the ": " separator (without the
# trailing newline, since "." does not match it).
regEx = re.compile('^([a-zA-Z0-9-]*):?: (.*)')


def decode_unicodePwd(value, kvno=0):
    """Print a base64 ``unicodePwd`` value as hex and re-encoded as a krb5Key.

    The decoded bytes are wrapped into a Heimdal keyblock of key type 23
    (printed below as arcfour-hmac-md5) and ASN.1-encoded without a salt.

    :param value: base64 text of the attribute value.
    :param kvno: key version number handed to ``heimdal.asn1_encode_key``.
    """
    global krb5_context
    if not krb5_context:
        krb5_context = Krb5Context()
    up_blob = binascii.a2b_base64(value)
    keyblock = heimdal.keyblock_raw(krb5_context.ctx, 23, up_blob)
    krb5key = heimdal.asn1_encode_key(keyblock, None, kvno)
    # binascii returns bytes; decode them so the output shows the plain
    # hex/base64 text instead of a b'...' literal.
    nt_hash_hex = binascii.b2a_hex(up_blob).upper().strip().decode('ASCII')
    krb5key_b64 = binascii.b2a_base64(krb5key).strip().decode('ASCII')
    print("# decoded:")
    print("#\tsambaNTPassword:: %s" % nt_hash_hex)
    print("# unicodePwd recoded as krb5key:")
    print("#\tkeytype: arcfour-hmac-md5")
    print("#\tkrb5Key:: %s" % krb5key_b64)


def decode_ntPwdHistory(value, kvno=0):
    """Print every entry of a base64 ``ntPwdHistory`` value, then each as a krb5Key.

    The decoded blob is cut into 16-byte entries. Each entry is first listed
    as 32 upper-case hex digits and afterwards wrapped into a Heimdal keyblock
    of key type 23 and ASN.1-encoded without a salt.

    :param value: base64 text of the attribute value.
    :param kvno: key version number handed to ``heimdal.asn1_encode_key``.
    """
    global krb5_context
    krb5_context = krb5_context or Krb5Context()

    raw_history = binascii.a2b_base64(value)
    hex_history = binascii.hexlify(raw_history).upper().strip()
    # Two hex digits per byte, so one 16-byte entry spans 32 digits.
    hex_entries = [
        hex_history[start:start + 32]
        for start in range(0, len(hex_history), 32)
    ]

    print("# decoded:")
    print("#\tntPwdHistory size: %s" % len(hex_entries))
    for position, hex_entry in enumerate(hex_entries):
        print("#\tntPwdHistory(idx_%s):: %s" % (position, hex_entry.decode('ASCII')))

    print("# ntPwdHistory recoded as krb5key:")
    print("#\tkeytype: arcfour-hmac-md5")
    for position, start in enumerate(range(0, len(raw_history), 16)):
        entry = raw_history[start:start + 16]
        keyblock = heimdal.keyblock_raw(krb5_context.ctx, 23, entry)
        encoded = heimdal.asn1_encode_key(keyblock, None, kvno)
        encoded_b64 = binascii.b2a_base64(encoded).strip().decode('ASCII')
        print("#\tkrb5Key(idx_%s):: %s" % (position, encoded_b64))


def decode_krb5Key(value):
    """Print key type, key material and salt of a base64 ``krb5Key`` value.

    Keys whose encryption type is not among the types permitted by the
    Heimdal context are reported and skipped. For key type 23 the key
    material is additionally printed as upper-case hex ("as NThash").

    :param value: base64 text of the attribute value.
    """
    global krb5_context
    if not krb5_context:
        krb5_context = Krb5Context()
    k = binascii.a2b_base64(value)
    (keyblock, salt, _kvno) = heimdal.asn1_decode_key(k)
    enctype = keyblock.keytype()
    enctype_id = enctype.toint()
    if enctype_id not in krb5_context.etype_ids:
        print("#\tSKIPPING ENC type %s, not support by this Heimdal version" % enctype_id)
        return
    print("#\tkrb5_keytype: %s (%d)" % (enctype, enctype_id))
    key_data = keyblock.keyvalue()
    # binascii returns bytes; decode so the text itself is printed rather
    # than a b'...' literal.
    print("#\tkeyblock: ", binascii.b2a_base64(key_data).strip().decode('ASCII'))
    if enctype_id == 23:
        print("#\tas NThash:", binascii.b2a_hex(key_data).strip().upper().decode('ASCII'))
    saltstring = salt.saltvalue()
    if isinstance(saltstring, bytes):
        # Only applies if the binding hands the salt back as bytes; any other
        # type is printed unchanged.
        saltstring = saltstring.decode('UTF-8', 'replace')
    print("#\tsaltstring: ", saltstring)


def decode_supplementalCredentials(value, kvno=0):
    """Print the keys stored in a base64 ``supplementalCredentials`` value.

    The value is unpacked as a ``supplementalCredentialsBlob``. Every package
    in it is then unpacked as a ``package_PrimaryKerberosBlob``, and each key
    found is printed together with its re-encoding as a salted krb5Key.
    A package that fails at any step gets its traceback printed as ``###``
    comment lines, and processing continues with the next package.

    :param value: base64 text of the attribute value.
    :param kvno: key version number handed to ``heimdal.asn1_encode_key``.
    """
    global krb5_context
    if not krb5_context:
        krb5_context = Krb5Context()
    object_data = ndr_unpack(drsblobs.supplementalCredentialsBlob, binascii.a2b_base64(value))
    print("# supplementalCredentials recoded as krb5key:")
    print("# object_data.sub.num_packages:", object_data.sub.num_packages)
    for p in object_data.sub.packages:
        print("#\tsupplementalCredentials package name: ", p.name)
        # The package payload is stored hex-encoded.
        krb_blob = binascii.unhexlify(p.data)
        try:
            krb = ndr_unpack(drsblobs.package_PrimaryKerberosBlob, krb_blob)
            print("#\tsupplementalCredentials package version: ", krb.version)
            print("#\tkrb5Salt: %s" % (krb.ctr.salt.string))
            for k in krb.ctr.keys:
                keytype = keytypes.get(k.keytype, k.keytype)
                print("#\tkeytype: %s (%d)" % (keytype, k.keytype))
                # The label goes out first so a failure below still shows
                # which field was being processed.
                print("#\tkeyblock:", end=' ')
                keyblock = heimdal.keyblock_raw(krb5_context.ctx, k.keytype, k.value)
                key_data = keyblock.keyvalue()
                # binascii returns bytes; decode to print the text itself
                # rather than a b'...' literal.
                print(binascii.b2a_base64(key_data).strip().decode('ASCII'))
                print("#\tkrb5SaltObject:", end=' ')
                krb5SaltObject = heimdal.salt_raw(krb5_context.ctx, krb.ctr.salt.string)
                salt_value = krb5SaltObject.saltvalue()
                if isinstance(salt_value, bytes):
                    # Only applies if the binding returns bytes; other types
                    # are printed unchanged.
                    salt_value = salt_value.decode('UTF-8', 'replace')
                print(salt_value)
                krb5key = heimdal.asn1_encode_key(keyblock, krb5SaltObject, kvno)
                print("#\tkrb5Key:: %s" % binascii.b2a_base64(krb5key).strip().decode('ASCII'))
        except Exception:
            print("### Exception during ndr_unpack")
            for line in traceback.format_exc().split('\n'):
                print("### %s" % line)


def decode_100nanosectimestamp(value):
    """Turn a count of 100-nanosecond intervals since 1601 into a timestamp string.

    :param value: the interval count; anything ``int()`` accepts.
    :returns: ``YYYY-MM-DD HH:MM:SS`` text with whole-second precision.

    NOTE(review): ``datetime.fromtimestamp`` without a ``tz`` argument yields
    local time, while the callers print "timezone not converted to local
    time" - confirm which of the two is intended.
    """
    epoch_shift = 116444736000000000  # difference between 1601 and 1970
    intervals_per_second = 10000000
    unix_seconds = (int(value) - epoch_shift) // intervals_per_second
    moment = datetime.fromtimestamp(unix_seconds)
    return moment.isoformat(' ')


def decode_accountExpires(value):
    """Print an ``accountExpires`` value in readable form.

    Both 0 and 0x7FFFFFFFFFFFFFFF are reported as "never", since Active
    Directory defines either value as "the account never expires". Any other
    value is converted with ``decode_100nanosectimestamp``; if that fails, the
    traceback is printed instead of being raised.

    :param value: the attribute value; anything ``int()`` accepts.
    """
    print("# decoded (Note: timezone not converted to local time):")
    # Without the 0 case the value would be shown as a date in the year 1601.
    if int(value) in (0, 0x7FFFFFFFFFFFFFFF):
        print("# accountExpires: never")
    else:
        try:
            print("# accountExpires: %s" % decode_100nanosectimestamp(value))
        except Exception:
            print("# traceback during decoding:")
            traceback.print_exc()


class decode_lastDate:
    """Callable decoder for timestamp attributes in which 0 is shown as "unknown".

    The attribute name given at construction time is only used to label the
    printed output.
    """

    def __init__(self, attrname):
        self.attrname = attrname

    def __call__(self, value):
        """Print ``value`` as a timestamp, or "unknown" when it is 0."""
        print("# decoded (Note: timezone not converted to local time):")
        if int(value) != 0:
            try:
                stamp = decode_100nanosectimestamp(value)
                print("# %s: %s" % (self.attrname, stamp))
            except Exception:
                # Decoding is best effort: report the problem and carry on.
                print("# traceback during decoding:")
                traceback.print_exc()
            return
        print("# %s: unknown" % self.attrname)


def decode_pwdLastSet(value):
    """Print a ``pwdLastSet`` value in readable form.

    -1 and 0 are printed as fixed messages; any other value is converted with
    ``decode_100nanosectimestamp``, and a failing conversion gets its
    traceback printed instead of being raised.

    :param value: the attribute value; anything ``int()`` accepts.
    """
    print("# decoded (Note: timezone not converted to local time):")
    fixed_messages = {
        -1: "# pwdLastSet: Just set",
        0: "# pwdLastSet: Must change on next logon",
    }
    number = int(value)
    if number in fixed_messages:
        print(fixed_messages[number])
        return
    try:
        print("# pwdLastSet: %s" % decode_100nanosectimestamp(value))
    except Exception:
        print("# traceback during decoding:")
        traceback.print_exc()


class decode_drsblob:
    """Callable decoder that NDR-unpacks a base64 value into ``blobtype`` and prints it."""

    def __init__(self, blobtype):
        # NDR structure class handed to ndr_unpack on every call.
        self.blobtype = blobtype

    def __call__(self, value):
        """Print the ``ndr_print`` rendering of ``value`` as ``# `` comment lines."""
        unpacked = ndr_unpack(self.blobtype, binascii.a2b_base64(value))
        print("# decoded:")
        rendering = ndr_print(unpacked)
        # Empty lines in the rendering are dropped.
        for text in filter(None, rendering.split('\n')):
            print("# %s" % text)


# Dispatch table used by decode_line: attribute name -> callable that is
# invoked with the attribute's value text and prints the decoded form.
# Attributes not listed here are echoed without any decoding.
objecttypes = {
    'dnsRecord': decode_drsblob(dnsp.DnssrvRpcRecord),
    'dNSProperty': decode_drsblob(dnsp.DnsProperty),
    'replPropertyMetaData': decode_drsblob(drsblobs.replPropertyMetaDataBlob),
    'repsFrom': decode_drsblob(drsblobs.repsFromToBlob),
    'repsTo': decode_drsblob(drsblobs.repsFromToBlob),
    'replUpToDateVector': decode_drsblob(drsblobs.replUpToDateVectorBlob),
    'supplementalCredentials': decode_supplementalCredentials,
    'unicodePwd': decode_unicodePwd,
    'ntPwdHistory': decode_ntPwdHistory,
    'krb5Key': decode_krb5Key,
    'pwdLastSet': decode_pwdLastSet,
    'accountExpires': decode_accountExpires,
    'badPasswordTime': decode_lastDate('badPasswordTime'),
    'lastLogoff': decode_lastDate('lastLogoff'),
    'lastLogon': decode_lastDate('lastLogon'),
}


def decode_line(line):
    """Echo one input line and, for known attributes, print its decoded form.

    :param line: a complete (already unfolded) input line, normally still
        ending in its newline character.
    """
    # The line carries its own newline, so nothing may be appended: a trailing
    # space would end up at the start of the next output line, where it reads
    # as a continuation marker.
    print(line, end='')
    res = regEx.search(line)
    if res:
        decoder = objecttypes.get(res.group(1))
        if decoder is not None:
            decoder(res.group(2))


# Read standard input line by line, unfolding continuation lines before
# handing each logical line to decode_line. A line starting with a single
# space continues the previous one: the previous line loses its last character
# (the newline) and the continuation is appended without its leading space.
try:
    pending = sys.stdin.readline()
    while pending:
        following = sys.stdin.readline()
        if following.startswith(' '):
            pending = pending[:-1] + following[1:]
            continue
        # "following" starts a new logical line (or is '' at end of input).
        decode_line(pending)
        pending = following
except Exception:
    traceback.print_exc()
    sys.stdout.flush()
