# Source code for pygan.gansvaluefunction.margin_loss
# -*- coding: utf-8 -*-
import numpy as np
from pygan.gans_value_function import GANsValueFunction
class MarginLoss(GANsValueFunction):
    """Margin-based value function for the energy-based GANs framework.

    The discriminator's reward is
    ``true_posterior + max(0, margin - generated_posterior)``. The margin is
    multiplied by ``margin_attenuate_rate`` once every ``attenuate_epoch``
    calls to :meth:`compute_discriminator_reward`.

    References:
        - Zhao, J., Mathieu, M., & LeCun, Y. (2016). Energy-based generative adversarial network. arXiv preprint arXiv:1609.03126.
    """

    def __init__(
        self,
        margin=1.0,
        margin_attenuate_rate=0.1,
        attenuate_epoch=50
    ):
        """Store the margin schedule and start an empty reward history.

        Args:
            margin: Initial margin of the hinge term.
            margin_attenuate_rate: Factor the margin is multiplied by each
                time it is attenuated.
            attenuate_epoch: Number of calls to
                `compute_discriminator_reward` between two attenuations of
                the margin. A value of 0 makes that method raise
                ZeroDivisionError.
        """
        # Hinge-margin schedule.
        self.__margin = margin
        self.__margin_attenuate_rate = margin_attenuate_rate
        self.__attenuate_epoch = attenuate_epoch
        # Counts calls to compute_discriminator_reward(); drives the schedule.
        self.__epoch = 0
        # One mean discriminator reward is appended per call.
        self.__discriminator_reward_list = []

    def compute_discriminator_reward(
        self,
        true_posterior_arr,
        generated_posterior_arr
    ):
        """Compute the discriminator's hinge-margin reward.

        The reward uses the margin in effect *before* this call advances the
        attenuation schedule. The mean of the result is recorded in the
        history exposed by `discriminator_reward_arr`.

        Args:
            true_posterior_arr: `np.ndarray` of posteriors the discriminator
                inferred from true samples.
            generated_posterior_arr: `np.ndarray` of posteriors the
                discriminator inferred from generated (fake) samples.

        Returns:
            `np.ndarray` of gradients, computed as
            ``true_posterior_arr + max(0, margin - generated_posterior_arr)``.
        """
        hinge_arr = np.maximum(0, self.__margin - generated_posterior_arr)
        reward_arr = true_posterior_arr + hinge_arr

        # Advance the call counter; attenuate the margin on every
        # `attenuate_epoch`-th call. Rebind rather than `*=` so a caller's
        # array-valued margin is never mutated in place.
        self.__epoch += 1
        if self.__epoch % self.__attenuate_epoch == 0:
            self.__margin = self.__margin * self.__margin_attenuate_rate

        self.__discriminator_reward_list.append(reward_arr.mean())
        return reward_arr

    def compute_generator_reward(
        self,
        generated_posterior_arr
    ):
        """Compute the generator's reward.

        Args:
            generated_posterior_arr: `np.ndarray` of posteriors the
                discriminator inferred from generated (fake) samples.

        Returns:
            `np.ndarray` of gradients. This is the input itself, returned
            unchanged.
        """
        return generated_posterior_arr

    def set_readonly(self, value):
        """Reject assignment to a read-only property."""
        raise TypeError("This property must be read-only.")

    def get_margin(self):
        """Return the current (possibly attenuated) margin."""
        return self.__margin

    margin = property(fget=get_margin, fset=set_readonly)

    def get_discriminator_reward_arr(self):
        """Return the recorded mean discriminator rewards as a new `np.ndarray`."""
        return np.array(self.__discriminator_reward_list)

    discriminator_reward_arr = property(
        fget=get_discriminator_reward_arr,
        fset=set_readonly
    )