Source code for pygan.truesampler.image_true_sampler

# -*- coding: utf-8 -*-
import numpy as np
from pygan.true_sampler import TrueSampler
from pydbm.cnn.featuregenerator.image_generator import ImageGenerator


class ImageTrueSampler(TrueSampler):
    '''
    Sampler which draws samples from the `true` distribution of images.

    Image batches are obtained from an `ImageGenerator` built in `__init__`
    and are normalized in `draw` according to `norm_mode`.
    '''

    def __init__(
        self,
        batch_size,
        image_dir,
        seq_len=None,
        gray_scale_flag=True,
        wh_size_tuple=(100, 100),
        norm_mode="z_score"
    ):
        '''
        Init.

        Args:
            batch_size:         Batch size.
            image_dir:          Dir path which stores image files.
            seq_len:            The length of one sequence.
            gray_scale_flag:    Gray scale or not(RGB).
            wh_size_tuple:      Tuple(`width`, `height`).
            norm_mode:          How to normalize pixel values of images.
                                - `z_score`: Z-Score normalization.
                                - `min_max`: Min-max normalization.
                                - `tanh`: Normalization by tanh function.
                                Any other value leaves the drawn samples
                                un-normalized by `draw`.
        '''
        # The same directory is passed as both the training and the test
        # source; `draw` only reads the first element of each generated tuple.
        # NOTE(review): `norm_mode` is also forwarded to `ImageGenerator`. If
        # the generator already normalizes, `draw` normalizes a second time
        # (not idempotent for `tanh`) -- confirm against ImageGenerator.
        self.__feature_generator = ImageGenerator(
            epochs=1,
            batch_size=batch_size,
            training_image_dir=image_dir,
            test_image_dir=image_dir,
            seq_len=seq_len,
            gray_scale_flag=gray_scale_flag,
            wh_size_tuple=wh_size_tuple,
            norm_mode=norm_mode
        )
        self.__norm_mode = norm_mode
        self.__seq_len = seq_len

    def draw(self):
        '''
        Draws samples from the `true` distribution.

        Only the first tuple produced by the feature generator is consumed;
        its first element is cast to float and normalized by `norm_mode`.

        Returns:
            `np.ndarray` of samples.

        Raises:
            ValueError: If the feature generator yields nothing.
        '''
        # A private sentinel distinguishes "nothing generated" from any value
        # the generator could legitimately yield.
        nothing = object()
        result_tuple = next(iter(self.__feature_generator.generate()), nothing)
        if result_tuple is nothing:
            raise ValueError(
                "The feature generator yielded no samples; "
                "check that the image directory contains image files."
            )

        observed_arr = result_tuple[0].astype(float)

        if self.__norm_mode == "z_score":
            std = observed_arr.std()
            # A constant array has zero deviation; leave it as is rather than
            # dividing by zero.
            if std != 0:
                observed_arr = (observed_arr - observed_arr.mean()) / std
        elif self.__norm_mode == "min_max":
            max_value = observed_arr.max()
            min_value = observed_arr.min()
            value_range = max_value - min_value
            # Same zero-division guard as above for a constant array.
            if value_range != 0:
                observed_arr = (observed_arr - min_value) / value_range
        elif self.__norm_mode == "tanh":
            observed_arr = np.tanh(observed_arr)

        return observed_arr

    def get_seq_len(self):
        ''' getter: the `seq_len` given at construction. '''
        return self.__seq_len

    def set_readonly(self, value):
        ''' setter: always raises `TypeError` because the property is read-only. '''
        raise TypeError("This property must be read-only.")

    # Read-only property: assignment is routed to `set_readonly`, which raises.
    seq_len = property(get_seq_len, set_readonly)