Source code for pygan.truesampler.gauss_true_sampler
# -*- coding: utf-8 -*-
import numpy as np
from pygan.true_sampler import TrueSampler
class GaussTrueSampler(TrueSampler):
    '''
    Sampler which draws samples from the `true` Gauss (normal) distribution.

    Every call to `draw` delegates to `np.random.normal` using the mean,
    standard deviation and output shape held by this instance.
    '''

    def __init__(self, mu, sigma, output_shape):
        '''
        Init.

        Args:
            mu:             `float` or `array_like of floats`.
                            Mean (`centre`) of the distribution.

            sigma:          `float` or `array_like of floats`.
                            Standard deviation (spread or `width`) of the distribution.

            output_shape:   Output shape.
                            The shape is `(batch size, d1, d2, d3, ...)`.
        '''
        # Double-underscore attributes are name-mangled by Python
        # (e.g. `_GaussTrueSampler__mu`), keeping them private to this class.
        self.__mu = mu
        self.__sigma = sigma
        self.__output_shape = output_shape

    def draw(self):
        '''
        Draws samples from the `true` distribution.

        The values are generated by `np.random.normal`, i.e. with NumPy's
        module-level random state, so seeding via `np.random.seed` affects
        the result.

        Returns:
            `np.ndarray` of samples whose shape is `output_shape`.
        '''
        return np.random.normal(
            loc=self.__mu,
            scale=self.__sigma,
            size=self.__output_shape
        )

    def get_output_shape(self):
        ''' getter for the shape of the arrays returned by `draw`. '''
        return self.__output_shape

    def set_output_shape(self, value):
        ''' setter for the shape of the arrays returned by `draw`. '''
        self.__output_shape = value

    # Expose the accessor pair as a plain attribute while keeping the
    # explicit getter/setter methods available to existing callers.
    output_shape = property(get_output_shape, set_output_shape)