Source code for pygan.truesampler.conditionaltruesampler.conditional_image_true_sampler

# -*- coding: utf-8 -*-
import numpy as np
from pygan.truesampler.conditional_true_sampler import ConditionalTrueSampler
from pygan.truesampler.image_true_sampler import ImageTrueSampler


class ConditionalImageTrueSampler(ConditionalTrueSampler):
    '''
    Sampler which draws samples from the conditional `true` distribution of images.

    Wraps an `ImageTrueSampler` and joins each drawn batch with a condition
    array along `conditional_axis`.
    '''

    # Axis along which the observed data and the condition are concatenated.
    # Kept as a class-level default; `set_conditional_axis` overrides it
    # per instance.
    __conditional_axis = 1

    def __init__(self, image_true_sampler):
        '''
        Init.

        Args:
            image_true_sampler:     is-a `ImageTrueSampler`.

        Raises:
            TypeError:  If `image_true_sampler` is not an `ImageTrueSampler`.
        '''
        if not isinstance(image_true_sampler, ImageTrueSampler):
            raise TypeError(
                "The type of `image_true_sampler` must be `ImageTrueSampler`, not `%s`."
                % (type(image_true_sampler).__name__,)
            )

        self.__image_true_sampler = image_true_sampler

    def draw(self):
        '''
        Draws samples from the `true` distribution.

        The batch drawn from the wrapped sampler is passed through
        `add_condition` before it is returned.

        Returns:
            `np.ndarray` of samples.
        '''
        observed_arr = self.__image_true_sampler.draw()
        observed_arr = self.add_condition(observed_arr)
        return observed_arr

    def add_condition(self, observed_arr):
        '''
        Add condition.

        Args:
            observed_arr:       `np.ndarray` of samples.

        Returns:
            `np.ndarray` of samples, concatenated with the condition
            along `conditional_axis`.
        '''
        if self.__image_true_sampler.seq_len is None:
            # No sequence dimension configured: a second, independent draw
            # from the wrapped sampler serves as the condition.
            condition_arr = self.__image_true_sampler.draw()
            return np.concatenate(
                (observed_arr, condition_arr),
                axis=self.conditional_axis
            )
        else:
            # `observed_arr` is indexed as a 5-dimensional array here; the
            # entries at index 0 and 1 of its second axis are joined.
            # NOTE(review): presumably the layout is
            # (batch, sequence, channel, height, width) -- confirm against
            # `ImageTrueSampler.draw`.
            return np.concatenate(
                (
                    observed_arr[:, 0, :, :, :],
                    observed_arr[:, 1, :, :, :]
                ),
                axis=self.conditional_axis
            )

    def get_conditional_axis(self):
        ''' Getter for the axis along which condition and data are joined. '''
        return self.__conditional_axis

    def set_conditional_axis(self, value):
        ''' Setter for the axis along which condition and data are joined. '''
        self.__conditional_axis = value

    conditional_axis = property(get_conditional_axis, set_conditional_axis)