Custom Loss Functions
The HighDimSynthesizer can be extended with custom loss functions to improve specific aspects of synthetic data quality. This feature allows you to add domain-specific objectives to the training process, such as improving correlations between certain columns or enforcing specific data properties.
Overview
Custom loss functions are added to the default training loss rather than replacing it, allowing you to steer the model toward specific objectives while maintaining overall synthesis quality.
Usage
Custom loss functions are passed to the learn() method via the custom_loss_fn parameter:
```python
from synthesized import HighDimSynthesizer, MetaExtractor

df_meta = MetaExtractor.extract(df)
synth = HighDimSynthesizer(df_meta=df_meta)

synth.learn(
    df_train=df,
    num_iterations=1000,
    custom_loss_fn=my_custom_loss,  # Your custom loss function
    custom_loss_weight=0.5,         # Weight applied to the custom loss
)
```
The custom_loss_weight parameter controls how much influence the custom loss has relative to the default training loss. Higher values give more weight to your custom objective.
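Conceptually, the weighted combination can be sketched in plain Python. This is an illustrative assumption about how the weight is applied (scaled custom loss added per sample to the default loss), not the library's internal code:

```python
# Illustrative sketch (assumed behavior, not the library's internals):
# the custom loss is scaled by custom_loss_weight and added, per sample,
# to the default training loss.
base_loss = [0.8, 1.2, 0.5]  # per-sample default training loss (made up)
custom = [0.2, 0.4, 0.1]     # per-sample custom loss (made up)
custom_loss_weight = 0.5

total_loss = [b + custom_loss_weight * c for b, c in zip(base_loss, custom)]
```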
Function Signature
A custom loss function must have the following signature:
```python
import tensorflow as tf
from typing import Dict

from synthesized.common.values import DataFrameValue

def custom_loss_fn(
    input_dict: Dict[str, tf.Tensor],   # Original input data by column
    output_dict: Dict[str, tf.Tensor],  # Model output logits by column
    df_value: DataFrameValue,           # Column metadata
) -> tf.Tensor:                         # Shape (batch_size,) or (batch_size, 1)
    ...
```
Parameters
- input_dict: A dictionary mapping column names to input tensors. Each tensor has shape (batch_size, 1) and contains the original data values (encoded as integers for categorical columns).
- output_dict: A dictionary mapping column names to output tensors. Each tensor contains the model's output logits before sampling, with shape (batch_size, 1, num_categories) for categorical columns.
- df_value: A DataFrameValue object containing metadata about each column, such as the number of categories for categorical columns.
Example: Simple Regularization Loss
Here’s a simple example that adds L2 regularization to the output logits:
```python
import tensorflow as tf

def regularization_loss(input_dict, output_dict, df_value):
    """Add L2 regularization to output logits."""
    total = None
    for name, output in output_dict.items():
        # Flatten to (batch_size, features) and take the mean squared logit
        flat = tf.reshape(output, (tf.shape(output)[0], -1))
        loss = tf.reduce_mean(tf.square(flat), axis=-1)
        if total is None:
            total = loss
        else:
            total = total + loss
    return total

synth.learn(df, custom_loss_fn=regularization_loss, custom_loss_weight=0.1)
```
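To make the reduction explicit, here is a pure-Python sketch of the same per-sample computation. The column names and logit values are made up for illustration:

```python
# Made-up logits, already flattened to (batch_size, features) per column.
logits_by_column = {
    "a": [[1.0, -2.0], [0.5, 0.5]],
    "b": [[3.0], [1.0]],
}

# Mean squared logit per sample, accumulated across columns, mirroring
# the reduction performed by regularization_loss above.
total = [0.0, 0.0]
for name, rows in logits_by_column.items():
    for i, row in enumerate(rows):
        total[i] += sum(x * x for x in row) / len(row)
```

Note that the result keeps one value per sample, which is the shape the framework expects from a custom loss.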
Example: Accessing Column Metadata
You can use the df_value parameter to access information about columns:
```python
import tensorflow as tf

from synthesized.common.values import CategoricalValue

def categorical_entropy_loss(input_dict, output_dict, df_value):
    """Encourage higher entropy in categorical predictions."""
    total_loss = None
    for name, value in df_value.items():
        if isinstance(value, CategoricalValue):
            logits = output_dict[name]
            logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
            probs = tf.nn.softmax(logits_flat, axis=-1)
            # Negative entropy: minimizing this maximizes entropy
            neg_entropy = tf.reduce_sum(probs * tf.math.log(probs + 1e-8), axis=-1)
            if total_loss is None:
                total_loss = neg_entropy
            else:
                total_loss = total_loss + neg_entropy
    if total_loss is None:
        # No categorical columns: return a zero loss per sample
        batch_size = tf.shape(list(output_dict.values())[0])[0]
        return tf.zeros((batch_size,), dtype=tf.float32)
    return total_loss

synth.learn(df, custom_loss_fn=categorical_entropy_loss, custom_loss_weight=0.5)
```
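To see why minimizing this quantity raises entropy, compare a uniform and a peaked distribution in plain Python. The probabilities below are made up for illustration:

```python
import math

def neg_entropy(probs):
    # Same quantity as in the loss above: sum(p * log(p + eps)).
    return sum(p * math.log(p + 1e-8) for p in probs)

uniform = [0.25, 0.25, 0.25, 0.25]  # high entropy
peaked = [0.97, 0.01, 0.01, 0.01]   # low entropy

# The uniform distribution gives the lower (more negative) value, so
# minimizing negative entropy pushes predictions toward uniformity.
```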
Best Practices
- Start with small weights: Begin with a custom_loss_weight of 0.1-0.5 and adjust based on results. Too high a weight may destabilize training.
- Return per-sample losses: Always return a tensor with one loss value per sample in the batch, not a scalar. This allows the framework to properly aggregate losses.
- Handle edge cases: Check for empty dictionaries or missing columns to make your loss function robust.
- Use TensorFlow operations: Ensure all computations use TensorFlow operations so gradients can flow through the custom loss.
- Monitor training: Watch the custom_loss metric during training to ensure your custom objective is being optimized.
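The per-sample requirement can be illustrated with made-up numbers: reduce over features only, never over the batch.

```python
# Made-up squared errors for a batch of 3 samples with 2 features each.
squared_errors = [[0.1, 0.3], [0.2, 0.2], [0.5, 0.1]]

# Correct: reduce over features only, keeping one value per sample.
per_sample = [sum(row) / len(row) for row in squared_errors]

# Avoid: also reducing over the batch returns a single scalar,
# which discards the per-sample structure the framework expects.
scalar = sum(per_sample) / len(per_sample)
```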
Monitoring Custom Loss
When a custom loss function is provided, an additional metric custom_loss is tracked during training. You can observe this metric in the training output to verify that your custom objective is being optimized:
```python
synth.learn(df, custom_loss_fn=my_loss, custom_loss_weight=0.5, verbose=1)
# Output will include: total_loss, custom_loss, ...
```