Source code for tfwatcher.callbacks.predict_batch

from statistics import mean
from typing import Union

import tensorflow as tf

from ..firebase_helpers import random_char, write_in_callback


class PredictBatchEnd(tf.keras.callbacks.Callback):
    """Log batch timings of ``predict`` to Firebase.

    This class is a subclass of the `tf.keras.callbacks.Callback
    <https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback>`_
    abstract base class and overrides the methods
    :func:`on_predict_batch_begin` and :func:`on_predict_batch_end` allowing
    logging after batches in the ``predict`` method. This class also uses the
    :mod:`.firebase_helpers` to send data to Firebase Realtime database and
    also creates a 7 character unique string where the data is pushed on
    Firebase. Logging to Firebase is also controllable by the ``schedule``
    argument, even providing a granular control for each batch in ``predict``
    methods.

    Example:

    .. code-block:: python
        :caption: Logging data after every batch in predict methods
        :emphasize-lines: 4,13
        :linenos:

        import tfwatcher

        # here we specify schedule = 1 to log after every batch
        monitor_callback = tfwatcher.callbacks.PredictBatchEnd(schedule=1)

        model.compile(
            optimizer=...,
            loss=...,
            # metrics which will be logged
            metrics=[...],
        )

        model.predict(..., callbacks=[monitor_callback])

    .. warning::
        If the ``steps_per_execution`` argument to compile in
        ``tf.keras.Model`` is set to N, the logging code will only be called
        every N batches.

    :param schedule: Use an integer value n to specify logging data every n
        batches, the first one being logged by default. Use a list of integers
        to control logging with a greater granularity, logs on all batch
        numbers specified in the list taking the first batch as batch 1. Using
        a list will override logging on the first batch by default, defaults
        to 1
    :type schedule: Union[int, list[int]], optional
    :param round_time: This argument allows specifying if you want to see the
        times on the web-app to be rounded, in most cases you would not be
        using this, defaults to 2
    :type round_time: int, optional
    :param print_logs: This argument should only be used when trying to debug
        if your logs do not appear in the web-app, if set to ``True`` this
        would print out the dictionary which is being pushed to Firebase,
        defaults to False
    :type print_logs: bool, optional
    :raises ValueError: If the ``schedule`` is neither an integer or a list,
        or if it is the integer 0.
    :raises Exception: If all the values in ``schedule`` list are not
        convertible to integer.
    """

    def __init__(
        self,
        schedule: Union[int, list] = 1,
        round_time: int = 2,
        print_logs: bool = False,
    ):
        super().__init__()
        self.schedule = schedule
        self.start_time = None
        self.end_time = None
        # Durations of the batches seen since the last time data was logged.
        self.times = []
        self.round_time = round_time
        self.print_logs = print_logs

        self.ref_id = random_char(7)
        print(f"Use this ID to monitor training for this session: {self.ref_id}")

        self.is_int = isinstance(self.schedule, int)
        self.is_list = not self.is_int and isinstance(self.schedule, list)
        if not (self.is_int or self.is_list):
            raise ValueError("schedule should either be an integer or a list")

        # ``(batch + 1) % 0`` would only fail with a ZeroDivisionError once
        # prediction has already started, so reject it up front instead.
        if self.is_int and self.schedule == 0:
            raise ValueError("schedule should be a non-zero integer")

        if self.is_list:
            try:
                self.schedule = list(map(int, self.schedule))
            except (ValueError, TypeError) as err:
                # Same exception type and message as before; ``from err``
                # keeps the original conversion error in the traceback.
                raise Exception(
                    "All elements in the list should be convertible to int: {}".format(
                        err
                    )
                ) from err

    def on_predict_batch_begin(self, batch: int, logs: Union[dict, None] = None):
        """Overrides the `tf.keras.callbacks.Callback.on_predict_batch_begin
        <https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_predict_batch_begin>`_
        method which is called at the beginning of a batch in predict methods.

        Records the time at which the batch started.

        :param batch: Index of batch within the current epoch
        :type batch: int
        :param logs: Dictionary passed in by Keras, not used by this method,
            defaults to None
        :type logs: dict, optional
        """

        self.start_time = tf.timestamp()

    def on_predict_batch_end(self, batch: int, logs: Union[dict, None] = None):
        """Overrides the `tf.keras.callbacks.Callback.on_predict_batch_end
        <https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback#on_predict_batch_end>`_
        method which is called at the end of a batch in predict methods. This
        method adds the batch number, the average time taken and pushes it to
        Firebase using the :mod:`.firebase_helpers` module.

        With an integer ``schedule`` data is logged on the first batch and on
        every batch whose 1-based number is a multiple of ``schedule``. With a
        list ``schedule`` data is logged only on the 1-based batch numbers
        contained in the list.

        :param batch: Index of batch within the current epoch
        :type batch: int
        :param logs: Dictionary passed in by Keras, not used by this method,
            defaults to None
        :type logs: dict, optional
        """

        self.end_time = tf.timestamp()

        # Use Python built in functions to allow using in @tf.function see
        # https://github.com/tensorflow/tensorflow/issues/27491#issuecomment-890887810
        elapsed = float(self.end_time - self.start_time)
        self.times.append(elapsed)

        batch_number = batch + 1
        if self.is_int:
            # The first batch is always logged for an integer schedule.
            should_log = batch == 0 or batch_number % self.schedule == 0
        else:
            # A list schedule overrides the default first-batch logging: only
            # the batch numbers listed are logged.
            should_log = batch_number in self.schedule

        if not should_log:
            return

        data = {
            "batch": batch_number,
            "epoch": False,
            "avg_time": round(mean(self.times), self.round_time),
        }

        write_in_callback(data=data, ref_id=self.ref_id)

        # The raw per-batch times are attached only after the write call, so
        # they are not in the dictionary at the time it is handed over.
        data["time"] = self.times
        if self.print_logs:
            print(data)

        # Start a fresh window of timings for the next logged batch.
        self.times = []