Creating custom Keras callbacks in Python
Upasana | December 07, 2019 | 6 min read
In this tutorial I am going to discuss how to create custom callbacks, i.e. logging batch results to stdout, streaming batch results to a CSV file, and terminating training on a NaN loss.
I was working with deep learning models using Keras in Python. Since there was not much variance in the results per epoch, I wanted to see the results per batch. That is when I came across NBatchLogger, which goes something like this:
from keras.callbacks import Callback

class NBatchLogger(Callback):
    """A logger that logs the average performance per `display` steps."""

    def __init__(self, display):
        super(NBatchLogger, self).__init__()
        self.step = 0
        self.display = display
        self.metric_cache = {}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.step += 1
        # Accumulate each metric so it can be averaged over `display` steps.
        for k in self.params['metrics']:
            if k in logs:
                self.metric_cache[k] = self.metric_cache.get(k, 0) + logs[k]
        if self.step % self.display == 0:
            metrics_log = ''
            for (k, v) in self.metric_cache.items():
                val = v / self.display
                # Switch to scientific notation for very small values.
                if abs(val) > 1e-3:
                    metrics_log += ' - %s: %.4f' % (k, val)
                else:
                    metrics_log += ' - %s: %.4e' % (k, val)
            print('step: {}/{} ... {}'.format(self.step,
                                              self.params['steps'],
                                              metrics_log))
            self.metric_cache.clear()
You just need to copy this class and paste it into the callbacks.py file of the Keras package on your system (or simply define it in your own script), and then you can use it in fit like this:
nbatch_logging = NBatchLogger(display=1)
model.fit(X_train, y_train, validation_split=val_split, verbose=0,
          epochs=num_epochs, batch_size=batch_size, callbacks=[nbatch_logging])
You just need to make sure that verbose is set to 0 so that the per-epoch and per-batch logs don't overlap.
Now I could see the logs per batch. And then I became greedy. I wanted more.
Keras provides an abstract class named Callback that we can extend to create a custom callback implementation. Here is the class diagram for the same.
I wanted to save these logs as well, set up an early stopping callback based on per-batch results if possible, and then use the logs to make graphs. In the Keras callbacks file there are six important methods to pay attention to when writing a custom callback (a minimal example follows the stubs below). Those are:
def on_epoch_begin(self, epoch, logs=None):
    pass

def on_epoch_end(self, epoch, logs=None):
    pass

def on_batch_begin(self, batch, logs=None):
    pass

def on_batch_end(self, batch, logs=None):
    pass

def on_train_begin(self, logs=None):
    pass

def on_train_end(self, logs=None):
    pass
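To make these hooks concrete, here is a minimal sketch of a custom callback that uses three of them. The class name and the timing idea are my own illustrative choices, not from the original post:
import time
from keras.callbacks import Callback

class BatchTimer(Callback):
    """Illustrative callback: records how long each batch takes."""

    def on_train_begin(self, logs=None):
        self.batch_times = []

    def on_batch_begin(self, batch, logs=None):
        # Remember when this batch started.
        self._start = time.time()

    def on_batch_end(self, batch, logs=None):
        # Store the elapsed wall-clock time for this batch.
        self.batch_times.append(time.time() - self._start)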
So we need to focus on just these methods and see which ones we would need in our case. As an example, I will try to explain an already existing Keras callback:
import numpy as np
from keras.callbacks import Callback

class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN loss is encountered."""

    def __init__(self):
        super(TerminateOnNaN, self).__init__()

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        if loss is not None:
            # Stop training as soon as the loss becomes NaN or infinite.
            if np.isnan(loss) or np.isinf(loss):
                print('Batch %d: Invalid loss, terminating training' % (batch))
                self.model.stop_training = True
This callback makes sure that the model stops training when a NaN shows up in the results. And if we think about it logically, it should check per batch and not only per epoch, since continuing after an invalid loss would just be a waste of time, and that is exactly how it works. It checks the results when a batch ends and then lets the model proceed based on that check: it reads the loss from the logs, uses NumPy to test whether it is NaN or infinite, and stops training if it is.
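Since TerminateOnNaN already ships with Keras, using it is just a matter of passing it to fit. A minimal sketch, reusing the variable names from the earlier fit call:
from keras.callbacks import TerminateOnNaN

model.fit(X_train, y_train, epochs=num_epochs, batch_size=batch_size,
          callbacks=[TerminateOnNaN()])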
To make a batch early stopping callback class, I read the EarlyStopping callback class and adapted it. It worked on just the second try.
import warnings
import numpy as np
from keras.callbacks import Callback

class BatchEarlyStopping(Callback):
    """Early stopping that counts batches instead of epochs."""

    def __init__(self, monitor='loss',
                 min_delta=0, patience=0, verbose=0, mode='auto'):
        super(BatchEarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_batch = 0
        if mode not in ['auto', 'min', 'max']:
            warnings.warn('BatchEarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # In auto mode, accuracy-like metrics should increase,
            # everything else (e.g. loss) should decrease.
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less
        # Flip the sign of min_delta so a single comparison works in both modes.
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.stopped_batch = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                'Batch early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' %
                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
            return
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            # No improvement for `patience` consecutive batches: stop training.
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_batch = batch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.stopped_batch > 0 and self.verbose > 0:
            print('Batch %05d: early stopping' % (self.stopped_batch + 1))
Replacing a few things worked like a charm for me. Note that the mode logic flips the sign of min_delta so that improvement can always be tested with a single comparison. We can use this callback by defining it as follows before calling fit:
batch_early_callback = BatchEarlyStopping(patience=500, monitor='loss')
Note: Using the batch early stopping callback is a bit tricky, since a sensible patience depends on the batch size and the number of training samples. For example, with 10,000 training samples and a batch size of 32, one epoch is roughly 313 batches, so a patience of 500 means waiting for more than a full epoch's worth of batches without improvement before stopping.
Next is saving the batch logs to a file.
import csv
import os
import numpy as np
import six
from collections import OrderedDict
try:
    from collections.abc import Iterable  # Python 3
except ImportError:
    from collections import Iterable  # Python 2 fallback
from keras.callbacks import Callback

class NBatchCSVLogger(Callback):
    """Callback that streams every batch's results to a CSV file."""

    def __init__(self, filename, separator=',', append=False):
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        # Python 2 on Windows needs binary mode to avoid extra blank rows.
        self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
        super(NBatchCSVLogger, self).__init__()

    def on_train_begin(self, logs=None):
        if self.append:
            if os.path.exists(self.filename):
                # Only write the header if the existing file is empty.
                with open(self.filename, 'r' + self.file_flags) as f:
                    self.append_header = not bool(len(f.readline()))
            self.csv_file = open(self.filename, 'a' + self.file_flags)
        else:
            self.csv_file = open(self.filename, 'w' + self.file_flags)

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}

        def handle_value(k):
            # Serialise iterables as quoted lists; pass scalars and strings through.
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, six.string_types):
                return k
            elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
                return '"[%s]"' % (', '.join(map(str, k)))
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())
        if self.model.stop_training:
            # Fill missing metrics with 'NA' on the final (aborted) batch.
            logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
        if not self.writer:
            class CustomDialect(csv.excel):
                delimiter = self.sep
            self.writer = csv.DictWriter(self.csv_file,
                                         fieldnames=['batch'] + self.keys,
                                         dialect=CustomDialect)
            if self.append_header:
                self.writer.writeheader()
        row_dict = OrderedDict({'batch': batch})
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        self.csv_file.close()
        self.writer = None
This callback saves the per-batch logs to a file and can come in handy for diagnosing the variance in results.
It can be defined as follows before calling fit:
batch_logg_saving = NBatchCSVLogger("batch_logs.csv", separator=',', append=False)
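Putting it all together, all three callbacks can be passed to a single fit call. A minimal sketch, reusing the variable names from the earlier snippets (X_train, y_train, val_split and so on are assumed to be defined):
callbacks = [nbatch_logging, batch_early_callback, batch_logg_saving]
model.fit(X_train, y_train, validation_split=val_split, verbose=0,
          epochs=num_epochs, batch_size=batch_size, callbacks=callbacks)
And since one of the goals was to make graphs from these logs, the saved CSV can be read back and plotted. A sketch assuming pandas and matplotlib are installed, and that 'loss' was among the logged metrics:
import pandas as pd
import matplotlib.pyplot as plt

logs = pd.read_csv("batch_logs.csv")
# Plot against the row index, since the batch column restarts every epoch.
plt.plot(logs['loss'].values)
plt.xlabel('batch (cumulative)')
plt.ylabel('loss')
plt.show()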
This is how I defined these custom callbacks, and a few others as well. I hope this will help others when defining their own custom callbacks.
So, we just need to see what we actually need and which method would help us get that end result. This matters because by default we only get results when an epoch ends, but sometimes our objective is something else. Let's say we want to save the best weights once the change in validation accuracy and validation loss becomes almost constant: for that, we can put the model training inside a try block and raise an exception from a callback, so that training stops when the loss is almost constant across batches. A rough sketch of that idea follows.
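Here the class name, threshold, patience value, and exception type are my own illustrative choices, not from the original post:
from keras.callbacks import Callback

class PlateauStop(Exception):
    """Raised when the monitored loss has effectively stopped changing."""
    pass

class StopOnPlateau(Callback):
    """Illustrative callback: raise once the loss change stays below a threshold."""

    def __init__(self, threshold=1e-4, patience=100):
        super(StopOnPlateau, self).__init__()
        self.threshold = threshold
        self.patience = patience
        self.prev_loss = None
        self.wait = 0

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        if loss is None:
            return
        if self.prev_loss is not None and abs(loss - self.prev_loss) < self.threshold:
            self.wait += 1
            if self.wait >= self.patience:
                # Save the current weights, then abort training via the exception.
                self.model.save_weights('best_weights.h5')
                raise PlateauStop('Loss almost constant for %d batches' % self.wait)
        else:
            self.wait = 0
        self.prev_loss = loss

try:
    model.fit(X_train, y_train, epochs=num_epochs, batch_size=batch_size,
              verbose=0, callbacks=[StopOnPlateau()])
except PlateauStop as e:
    print(e)
Thanks for reading this article. I hope you found it useful.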