# Source code for skbel.preprocessing.dct

#  Copyright (c) 2021. Robin Thibaut, Ghent University

"""Discrete Cosine Transform written in scikit-learn style."""

import numpy as np
from scipy.fftpack import dct, idct
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_array

# Public API for `from ... import *`: only the transformer class is exported;
# the dct2/idct2 helpers below remain importable by name.
__all__ = ["DiscreteCosineTransform2D"]


def dct2(a):
    """Return the orthonormal 2D DCT (type II) of ``a``.

    A 1D DCT with ``norm="ortho"`` is applied first along axis 0 and then
    along the last axis. Because of the orthonormal scaling, :func:`idct2`
    inverts it exactly (up to floating-point error).
    """
    along_rows = dct(a, axis=0, norm="ortho")
    return dct(along_rows, axis=-1, norm="ortho")


def idct2(a):
    """Return the orthonormal 2D inverse DCT of ``a``.

    A 1D inverse DCT with ``norm="ortho"`` is applied first along axis 0 and
    then along the last axis; this undoes :func:`dct2`.
    """
    along_rows = idct(a, axis=0, norm="ortho")
    return idct(along_rows, axis=-1, norm="ortho")


class DiscreteCosineTransform2D(TransformerMixin, BaseEstimator):
    """Truncated 2D Discrete Cosine Transform as a scikit-learn transformer.

    Each sample ``X[i]`` of a 3D input ``(n_samples, n_rows, n_cols)`` is
    transformed with the orthonormal :func:`dct2`. Only the leading
    ``m_cut x n_cut`` block of coefficients is kept and flattened, giving an
    output of shape ``(n_samples, m_cut * n_cut)``. :meth:`inverse_transform`
    zero-pads the kept coefficients back to ``(n_rows, n_cols)`` and applies
    :func:`idct2`.

    Parameters
    ----------
    m_cut : int, optional
        Number of coefficient rows to keep. If None, it is set to ``n_rows``
        on the first call to :meth:`transform`.
    n_cut : int, optional
        Number of coefficient columns to keep. If None, it is set to
        ``n_cols`` on the first call to :meth:`transform`.

    Attributes
    ----------
    n_rows, n_cols : int or None
        Shape of the 2D samples seen by the most recent :meth:`transform`
        call; required by :meth:`inverse_transform`.
    """

    def __init__(self, *, m_cut: int = None, n_cut: int = None):
        # Original shape of each 2D sample (set by transform)
        self.n_rows = None
        self.n_cols = None
        # Number of components to keep along each axis
        self.m_cut = m_cut
        self.n_cut = n_cut

    def fit(self, X, y=None):
        """Return ``self`` unchanged.

        Nothing is learned here; all shape information is gathered in
        :meth:`transform`. ``y`` is ignored and optional so ``fit(X)`` works.
        """
        return self

    @staticmethod
    def _as_sample_stack(X):
        """Validate ``X`` and return it with a leading sample axis.

        A 2D input is treated as a single sample of shape (n_rows, n_cols).
        """
        try:
            X = check_array(X, allow_nd=True)
        except ValueError:
            # check_array rejects 1D input: view it as a single row.
            X = check_array(np.asarray(X).reshape(1, -1))
        if X.ndim == 2:
            # A lone 2D array previously crashed on X.shape[2] (IndexError);
            # wrap it as a stack of one sample instead.
            X = X[np.newaxis, ...]
        return X

    def transform(self, X):
        """Return the truncated, flattened 2D DCT coefficients of each sample.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_rows, n_cols)
            Stack of 2D arrays. A single 2D array is treated as one sample.

        Returns
        -------
        ndarray of shape (n_samples, m_cut * n_cut)
        """
        X = self._as_sample_stack(X)
        # Remember the sample shape so inverse_transform can zero-pad back.
        self.n_rows = X.shape[1]
        self.n_cols = X.shape[2]
        # NOTE(review): resolving the defaults writes back into the constructor
        # parameters, so the cut is frozen by the first call (sklearn advises
        # against mutating __init__ params). Kept for backward compatibility
        # with callers that read m_cut / n_cut after transform.
        if self.m_cut is None:
            self.m_cut = self.n_rows
        if self.n_cut is None:
            self.n_cut = self.n_cols
        X_dct = np.array([dct2(e)[: self.m_cut, : self.n_cut] for e in X])
        return X_dct.reshape((X_dct.shape[0], -1))

    def inverse_transform(self, X):
        """Reconstruct the 2D arrays from truncated DCT coefficients.

        Parameters
        ----------
        X : array-like of shape (n_samples, m_cut * n_cut)
            Output of :meth:`transform`. A 1D vector is treated as one sample.

        Returns
        -------
        ndarray of shape (n_samples, n_rows, n_cols)

        Notes
        -----
        Requires a prior call to :meth:`transform` so that ``n_rows``,
        ``n_cols``, ``m_cut`` and ``n_cut`` are set.
        """
        try:
            X = check_array(X, allow_nd=True)
        except ValueError:
            # A single flattened coefficient vector (1D input).
            X = check_array(np.asarray(X).reshape(1, -1))
        X = X.reshape(-1, self.m_cut, self.n_cut)
        # Coefficients with indices beyond m_cut / n_cut were discarded by
        # transform; restore them as zeros before inverting.
        padded = np.zeros((X.shape[0], self.n_rows, self.n_cols))
        padded[:, : self.m_cut, : self.n_cut] = X
        return np.array([idct2(e) for e in padded])

    def fit_transform(self, X, y=None, **fit_params):
        """Equivalent to ``fit(X, y).transform(X)``; ``fit_params`` are ignored."""
        return self.fit(X, y).transform(X)