Source code for pygan.noisesampler.gauss_noise_sampler
# -*- coding: utf-8 -*-
import numpy as np
from pygan.noise_sampler import NoiseSampler
class GaussNoiseSampler(NoiseSampler):
    '''
    Generate samples based on the noise prior by Gauss distribution.

    Samples are drawn with `np.random.normal`, i.e. from NumPy's
    module-level random state.
    '''

    def __init__(self, mu, sigma, output_shape):
        '''
        Init.

        Args:
            mu:             `float` or `array_like of floats`.
                            Mean (`centre`) of the distribution.

            sigma:          `float` or `array_like of floats`.
                            Standard deviation (spread or `width`) of the distribution.

            output_shape:   Output shape.
                            the shape is `(batch size, d1, d2, d3, ...)`.
        '''
        self.__mu = mu
        self.__sigma = sigma
        self.__output_shape = output_shape

    def generate(self):
        '''
        Generate noise samples.

        If `self.noise_sampler` is not `None`, its `output_shape` is
        overwritten with the shape of the freshly drawn samples and the
        result of its `generate()` is added on top of them.

        Returns:
            `np.ndarray` of samples.
        '''
        generated_arr = np.random.normal(
            loc=self.__mu,
            scale=self.__sigma,
            size=self.__output_shape
        )

        # NOTE(review): `noise_sampler` is not defined in this class;
        # presumably it is provided by `NoiseSampler` -- confirm.
        if self.noise_sampler is not None:
            # `np.shape` (rather than `.shape`) also works when NumPy hands
            # back a plain scalar, which happens for `size=None` combined
            # with scalar `loc`/`scale`.
            self.noise_sampler.output_shape = np.shape(generated_arr)
            generated_arr += self.noise_sampler.generate()

        return generated_arr

    def get_output_shape(self):
        ''' getter for the shape of the generated samples. '''
        return self.__output_shape

    def set_output_shape(self, value):
        ''' setter for the shape of the generated samples. '''
        self.__output_shape = value

    # Public attribute-style access to the output shape.
    output_shape = property(get_output_shape, set_output_shape)