# Adaptive loudness averaging.
# Hens Zimmerman, 02-05-2023.
# Python 3.

import matplotlib.pyplot as plt
import numpy
import os
import pyloudnorm
import regex
import scipy 
import soundfile
import sys
import warnings

# Fixed tuning for the stereo look-ahead limiter; none of these can be set
# from the command line.  Both coefficients are applied once per sample.

delay: int = 40                 # delay-line length in samples (look-ahead)
signal_length: int = 1          # seconds; unused in this script
release_coeff: float = 0.9995   # per-sample decay factor of the peak envelope
attack_coeff: float = 0.9       # per-sample smoothing weight of the limiter gain
block_length: int = 1024        # samples; unused in this script

# End of fixed parameters.

class StereoLimiter:
    """Look-ahead peak limiter for two-channel audio with ganged gain.

    Each channel keeps a peak envelope that jumps to new peaks instantly and
    otherwise decays by ``release_coeff`` per sample.  While an envelope is
    above ``threshold`` the channel's target gain is ``threshold / envelope``
    (else 1.0).  Each channel's gain glides toward its target with one-pole
    smoothing weighted by ``attack_coeff``.  The smaller of the two gains is
    applied to both channels.  The gain is applied to samples read back from
    a circular delay line, so gain reduction can begin before a peak reaches
    the output.

    All state lives on the instance, so a signal can be fed through ``limit``
    in consecutive chunks with the same result as a single call.
    """

    def __init__(self, attack_coeff, release_coeff, delay, threshold):
        # Tuning parameters.
        self.attack_coeff = attack_coeff
        self.release_coeff = release_coeff
        self.threshold = threshold
        self.delay = delay

        # Per-channel circular delay lines sharing one write position.
        self.delay_line_left = numpy.zeros(delay)
        self.delay_line_right = numpy.zeros(delay)
        self.delay_index = 0

        # Running peak envelopes.
        self.envelope_left = 0
        self.envelope_right = 0

        # Smoothed per-channel gains and the ganged gain actually applied.
        self.gain_left = 1
        self.gain_right = 1
        self.gain = 1

    def limit(self, signal):
        """Limit ``signal`` in place and return that same object.

        ``signal`` is indexed as ``signal[n][ch]`` with ch 0 = left and
        ch 1 = right.  Each output row is the delayed input row scaled by the
        current ganged gain.

        NOTE(review): because the newest frame is written before the read
        position advances, a delay line of length ``delay`` gives ``delay - 1``
        samples of latency.  That latency is not compensated: the first rows
        come from the zero-initialised delay line, and the final input rows
        stay buffered instead of being emitted.
        """
        threshold = self.threshold
        attack = self.attack_coeff
        release = self.release_coeff

        for row, frame in enumerate(signal):
            in_left, in_right = frame[0], frame[1]

            # Store the newest frame, then advance so the position points at
            # the oldest stored frame (the one emitted below).
            pos = self.delay_index
            self.delay_line_left[pos] = in_left
            self.delay_line_right[pos] = in_right
            pos = (pos + 1) % self.delay
            self.delay_index = pos

            # Peak envelopes: follow a new peak immediately, otherwise decay.
            env_left = max(abs(in_left), self.envelope_left * release)
            env_right = max(abs(in_right), self.envelope_right * release)
            self.envelope_left = env_left
            self.envelope_right = env_right

            # Gain each channel needs to sit exactly at the threshold.
            want_left = threshold / env_left if env_left > threshold else 1.0
            want_right = threshold / env_right if env_right > threshold else 1.0

            # One-pole glide of each gain toward its target.
            self.gain_left = self.gain_left * attack + want_left * (1 - attack)
            self.gain_right = self.gain_right * attack + want_right * (1 - attack)

            # Gang the channels: the stronger reduction wins.
            applied = min(self.gain_left, self.gain_right)
            self.gain = applied

            # Emit the delayed frame, attenuated.
            signal[row][0] = self.delay_line_left[pos] * applied
            signal[row][1] = self.delay_line_right[pos] * applied

        return signal

# pyloudnorm warns when a gain change pushes samples past +/-1.0.  The audio
# is processed in float64, so such overshoots are kept intact, and the final
# peak check / limiter stage handles peaks before the file is written.
# NOTE: this is a catch-all "ignore" filter; it silences every warning from
# every module, not only pyloudnorm's clipping warning.

warnings.filterwarnings("ignore")

# What command line args did we get?  argv[1] is the input file; anything
# after it is an optional key:value setting.

arg_count = len(sys.argv)

if arg_count < 2:
    # sys.exit(str) prints the message to stderr and exits with status 1, so a
    # calling shell can tell a usage error from success.  The bare exit()
    # used before comes from the site module and reported status 0.
    sys.exit("python dyn_adapt.py file div:xx loudness:-xx xfade:xx lower:xx max-up:x max-down:x oversample:x limit:-x")

# Name of input file.

filename = sys.argv[1]

# Fail early, with a non-zero status, if the input file does not exist.

if not os.path.isfile(filename):
    sys.exit(filename + " doesn't appear to exist")

# ---- Defaults; each can be overridden on the command line ---------------

# Cut the file into this many equal blocks, or into blocks of this many
# seconds when `seconds` is set (div:Ns).
division: int = 10
seconds: bool = False

# Fraction of a block that is crossfaded into the previous block.
xfade: float = 0.5

# Integrated loudness target for the final file.
final_loudness: float = -16.0

# Blocks measuring at or below (final_loudness - lower) are not adapted.
lower: float = 12.0

# Largest loudness change allowed per block, upwards and downwards.
max_upwards: float = 6.0
max_downwards: float = 6.0

# Oversampling factor used when measuring the output peak.
oversampling: int = 4

# Limiter ceiling in dBFS.
limit: int = -1

# Scan the optional key:value arguments that override the defaults, e.g.
#   div:10 div:10s loudness:-16 xfade:90 lower:12 max-up:6 max-down:6 oversample:4 limit:-2
# Keys are case-insensitive and each argument must match one form exactly.
# loudness, lower, max-up, max-down and limit accept decimals (loudness:-16.5,
# limit:-0.3); previously the fraction was silently dropped.  div, xfade (a
# percentage) and oversample take whole numbers.

number_pattern = r"(\d+(?:\.\d+)?)"

for arg in sys.argv[2:]:
    match = regex.fullmatch(r"div:(\d+)(s?)", arg, regex.IGNORECASE)
    if match:
        division = int(match.group(1))
        if match.group(2):
            seconds = True
        continue

    match = regex.fullmatch(r"loudness:-" + number_pattern, arg, regex.IGNORECASE)
    if match:
        final_loudness = -float(match.group(1))
        continue

    match = regex.fullmatch(r"xfade:(\d+)", arg, regex.IGNORECASE)
    if match:
        xfade = int(match.group(1)) / 100
        continue

    match = regex.fullmatch(r"lower:" + number_pattern, arg, regex.IGNORECASE)
    if match:
        lower = float(match.group(1))
        continue

    match = regex.fullmatch(r"max-up:" + number_pattern, arg, regex.IGNORECASE)
    if match:
        max_upwards = float(match.group(1))
        continue

    match = regex.fullmatch(r"max-down:" + number_pattern, arg, regex.IGNORECASE)
    if match:
        max_downwards = float(match.group(1))
        continue

    match = regex.fullmatch(r"oversample:(\d+)", arg, regex.IGNORECASE)
    if match:
        oversampling = int(match.group(1))
        continue

    match = regex.fullmatch(r"limit:-" + number_pattern, arg, regex.IGNORECASE)
    if match:
        limit = -float(match.group(1))
        continue

    # A mistyped option used to be ignored without any feedback.
    print("Ignoring unrecognized argument: " + arg)

# Reject values the block processing cannot handle: zero blocks divide by
# zero, and a crossfade longer than a block would index before its start.

if division < 1:
    sys.exit("div must be at least 1")

if xfade > 1.0:
    sys.exit("xfade must be at most 100")

lower_threshold = final_loudness - lower

# Read the entire file into a float64 ndarray.  always_2d=True makes soundfile
# return shape (frames, channels) even for mono input, so the array is always
# two-dimensional and the channel count has to be checked explicitly (the
# old ndim test could never reject a mono file).

audio, samplerate = soundfile.read(filename, frames=-1, dtype='float64', always_2d=True)

samples, channels = audio.shape

# Only stereo is supported: StereoLimiter reads channels 0 and 1.

if channels == 1:
    sys.exit("Mono files are not supported")

if channels > 2:
    sys.exit("Only stereo audio is currently supported")

# Zero blocks (div:0 or div:0s) would divide by zero below.

if division < 1:
    sys.exit("div must be at least 1")

# Division of the file into blocks of `blocksize` samples.  With div:Ns the
# blocks are N seconds long and the count follows from the file length; at
# least one block is kept, so a file shorter than N seconds is processed
# whole.  Otherwise the file is cut into `division` equal blocks.  Samples
# left over after the last full block are appended to the last block by the
# processing loop.

if seconds:
    blocksize = division * samplerate
    division = max(1, samples // blocksize)
else:
    blocksize = samples // division

# pyloudnorm.Meter measures over 0.4 s gating windows by default and raises
# ValueError for shorter input.  Report too-short blocks (or a too-short
# file) here instead of failing midway through processing.

shortest_block = min(blocksize, samples)

if shortest_block < 0.4 * samplerate:
    sys.exit("Blocks of {0} samples are too short for loudness measurement; use fewer/longer blocks".format(shortest_block))

print("Block size: {0} samples".format(blocksize))

# Crossfade length in samples: a fraction `xfade` of one block, capped at a
# whole block so the overlap never reaches back past the previous block.

fadesize = min(int(blocksize * xfade), blocksize)

# BS.1770 loudness meter.

meter = pyloudnorm.Meter(samplerate)

# Finished blocks are collected here and joined once after the loop.  Calling
# numpy.append per block re-copied all output so far on every iteration.

finished_blocks = []
prev_audio = numpy.empty((0, channels))

for idx in range(0, division):
    print("Processing block {0} of {1}".format(idx + 1, division))

    # Block idx nominally covers samples [idx * blocksize, (idx + 1) * blocksize).
    # Every block but the first starts `fadesize` samples early, so its head
    # overlaps the tail of the previous block.  The last block runs to the end
    # of the file to pick up any leftover samples.

    start_idx = max(idx * blocksize - fadesize, 0)
    stop_idx = (idx + 1) * blocksize

    if idx == division - 1:
        sub_audio = audio[start_idx:]
    else:
        sub_audio = audio[start_idx:stop_idx]

    # Loudness adapt this block: move it toward final_loudness by at most
    # max_downwards / max_upwards.  Blocks at or below lower_threshold are left
    # untouched ("silent" passages).  A -inf or NaN reading also compares False
    # here, so it is skipped too.

    loudness = meter.integrated_loudness(sub_audio)

    if loudness > lower_threshold:
        if loudness > final_loudness:
            block_target = max(loudness - max_downwards, final_loudness)
        else:
            block_target = min(loudness + max_upwards, final_loudness)

        # This may push samples outside [-1.0, 1.0].  That is harmless in
        # float64; the peak check and limiter at the end of the script take
        # care of peaks.
        sub_audio = pyloudnorm.normalize.loudness(sub_audio, loudness, block_target)

    if idx > 0:
        # Linear crossfade over the overlap.  The new block's weight is
        # jdx / fadesize at overlap position jdx; the previous block gets the
        # remainder.  prev_audio may be a slice (view) of `audio` when its
        # block was not rescaled, so blend into a copy rather than writing
        # through to the input samples.
        if fadesize > 0:
            ramp = (numpy.arange(fadesize) * (1.0 / fadesize))[:, numpy.newaxis]
            overlap = slice(blocksize - fadesize, blocksize)
            prev_audio = prev_audio.copy()
            prev_audio[overlap] = (1.0 - ramp) * prev_audio[overlap] + ramp * sub_audio[:fadesize]

        # The overlap now lives at the end of the previous block, which is final.
        finished_blocks.append(prev_audio)
        sub_audio = sub_audio[fadesize:]

    # This block becomes the previous block for the next iteration.

    prev_audio = sub_audio

# The last block has no successor to crossfade with; add it as is.

finished_blocks.append(prev_audio)
new_audio = numpy.concatenate(finished_blocks, axis=0)

# Gain scale the stitched result to the requested integrated loudness.
# A silent result measures as -inf (or NaN).  Normalizing it would apply an
# infinite/NaN gain and turn every sample into NaN, so it is left unchanged.

loudness = meter.integrated_loudness(new_audio)

if numpy.isfinite(loudness):
    new_audio = pyloudnorm.normalize.loudness(new_audio, loudness, final_loudness)
else:
    print("Output measures as silent; skipping final loudness normalization")

# Sample peak in dBFS (-inf for an all-zero buffer).

peak_dB = 20.0 * numpy.log10(numpy.max(numpy.abs(new_audio)))

print("Sample peak at " + str(peak_dB) + " dBFS")

if oversampling > 1:
    # Import the subpackage explicitly: on older SciPy releases a bare
    # `import scipy` does not make `scipy.signal` available.
    import scipy.signal

    print("Oversampling... hold on to your seats...")

    # Resample (FFT-based) to `oversampling` times as many frames to catch
    # inter-sample peaks the sample peak above misses.  Use the buffer's own
    # length rather than assuming it equals the input's.
    oversampled_new_audio = scipy.signal.resample(new_audio, new_audio.shape[0] * oversampling)

    peak_dB = 20.0 * numpy.log10(numpy.max(numpy.abs(oversampled_new_audio)))

    print("Oversampled peak at " + str(peak_dB) + " dBFS")

# The output goes next to the input: same path and stem with "_new.wav"
# appended, e.g. "mix.flac" -> "mix_new.wav".  os.path.splitext handles any
# extension length (or none).  The old fixed four-character chop mangled
# names such as "x.flac" into "x.f_new.wav".

new_name = os.path.splitext(filename)[0] + '_new.wav'

# Run the look-ahead limiter only when the measured peak (oversampled when
# oversampling > 1) exceeds the requested ceiling.
# NOTE(review): StereoLimiter emits samples from its delay line and is never
# flushed.  The limited output is therefore shifted later by the limiter's
# latency, and its last few input frames are not written.

if peak_dB > limit:
    threshold = 10 ** (limit / 20)
    limiter = StereoLimiter(attack_coeff, release_coeff, delay, threshold)
    limited_new_audio = limiter.limit(new_audio)
    soundfile.write(new_name, limited_new_audio, samplerate, 'PCM_24')
else:
    soundfile.write(new_name, new_audio, samplerate, 'PCM_24')