
BaseCSVWriter

class torchtnt.framework.callbacks.BaseCSVWriter(header_row: List[str], dir_path: str, delimiter: str = '\t', filename: str = 'predictions.csv')

A callback to write prediction outputs to a CSV file.

This callback provides an interface that simplifies writing outputs produced during prediction to a CSV file. It must be extended with an implementation of get_step_output_rows that returns the desired outputs as rows of the CSV file.

By default, outputs from each step across all processes are written to the same CSV file. Each row is a list of strings whose entries should match the column names defined in header_row.
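The following is a minimal, standalone sketch of the extension pattern described above. It does not import torchtnt: MiniCSVWriter, IdAndLabelWriter, and render are hypothetical names, and the get_step_output_rows used here takes only the step output and returns a list of rows. That is a simplified signature assumed for illustration, not the library's actual one.

```python
import csv
import io
from abc import ABC, abstractmethod
from typing import Any, List


class MiniCSVWriter(ABC):
    """Hypothetical stand-in for BaseCSVWriter (NOT the torchtnt class)."""

    def __init__(self, header_row: List[str], delimiter: str = "\t") -> None:
        self.header_row = header_row
        self.delimiter = delimiter

    @abstractmethod
    def get_step_output_rows(self, step_output: Any) -> List[List[str]]:
        """Turn one step's output into rows of strings (simplified, assumed signature)."""

    def render(self, step_outputs: List[Any]) -> str:
        buf = io.StringIO()
        writer = csv.writer(buf, delimiter=self.delimiter, lineterminator="\n")
        writer.writerow(self.header_row)
        for step_output in step_outputs:
            for row in self.get_step_output_rows(step_output):
                # Each row must provide one string per column in header_row.
                if len(row) != len(self.header_row):
                    raise ValueError(
                        f"expected {len(self.header_row)} columns, got {len(row)}"
                    )
                writer.writerow(row)
        return buf.getvalue()


class IdAndLabelWriter(MiniCSVWriter):
    """Hypothetical subclass: each step output is a pair (ids, labels)."""

    def get_step_output_rows(self, step_output: Any) -> List[List[str]]:
        ids, labels = step_output
        return [[str(i), str(label)] for i, label in zip(ids, labels)]


text = IdAndLabelWriter(header_row=["id", "label"]).render(
    [([0, 1], [3, 7]), ([2], [5])]
)
print(text, end="")
```

Checking the column count against header_row in one place catches subclasses whose rows drift out of sync with the header.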

Parameters:
  • header_row – column names of the CSV file
  • dir_path – path of the directory in which to save the CSV file
  • delimiter – character used to separate columns within a row. Default is tab ('\t')
  • filename – name of the file. Default is “predictions.csv”
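As a rough illustration of how these parameters combine, the sketch below assumes the output file is located at os.path.join(dir_path, filename) and that rows are written with Python's csv module using the given delimiter. write_predictions is a hypothetical helper, not part of torchtnt.

```python
import csv
import os
import tempfile
from typing import List


def write_predictions(
    header_row: List[str],
    rows: List[List[str]],
    dir_path: str,
    delimiter: str = "\t",
    filename: str = "predictions.csv",
) -> str:
    """Hypothetical helper: write the header and rows, then return the file path."""
    # Assumption: the file lives at dir_path/filename.
    path = os.path.join(dir_path, filename)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, delimiter=delimiter)
        writer.writerow(header_row)
        writer.writerows(rows)
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = write_predictions(["id", "score"], [["0", "0.9"]], dir_path=tmp)
    with open(path, newline="") as f:
        content = f.read()  # tab-separated: "id<TAB>score", then "0<TAB>0.9"
```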

on_exception(state: State, unit: Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], exc: BaseException) → None

Hook called when an exception occurs.

on_predict_end(state: State, unit: PredictUnit[TPredictData]) → None

Hook called after prediction ends.

on_predict_start(state: State, unit: PredictUnit[TPredictData]) → None

Hook called before prediction starts.

on_predict_step_end(state: State, unit: PredictUnit[TPredictData]) → None

Hook called after a predict step ends.
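The hooks above suggest a natural lifecycle for the output file. The sketch below is a hypothetical, single-process mimic of that lifecycle: it opens the file and writes the header in on_predict_start, appends rows in on_predict_step_end, and closes the file in on_predict_end or on_exception. It is not torchtnt's implementation. In particular, these hooks take the step output directly instead of State and PredictUnit, and the sketch does not gather outputs across processes. LifecycleCSVSketch is a hypothetical name.

```python
import csv
import os
import tempfile
from typing import Any, List, Optional, TextIO


class LifecycleCSVSketch:
    """Hypothetical single-process mimic of the predict hooks (NOT torchtnt code)."""

    def __init__(self, header_row: List[str], dir_path: str,
                 delimiter: str = "\t", filename: str = "predictions.csv") -> None:
        self.header_row = header_row
        self.path = os.path.join(dir_path, filename)
        self.delimiter = delimiter
        self._file: Optional[TextIO] = None
        self._writer: Any = None

    def get_step_output_rows(self, step_output: Any) -> List[List[str]]:
        # Simplified: one row per step, each value converted to a string.
        return [[str(x) for x in step_output]]

    def on_predict_start(self) -> None:
        # Open the file once and write the header row.
        self._file = open(self.path, "w", newline="")
        self._writer = csv.writer(self._file, delimiter=self.delimiter)
        self._writer.writerow(self.header_row)

    def on_predict_step_end(self, step_output: Any) -> None:
        # Append the rows produced by this step.
        self._writer.writerows(self.get_step_output_rows(step_output))

    def on_predict_end(self) -> None:
        self._close()

    def on_exception(self, exc: BaseException) -> None:
        # Close the file so rows written before the failure are flushed to disk.
        self._close()

    def _close(self) -> None:
        if self._file is not None:
            self._file.close()
            self._file = None


with tempfile.TemporaryDirectory() as tmp:
    cb = LifecycleCSVSketch(["a", "b"], dir_path=tmp)
    cb.on_predict_start()
    for out in [(1, 2), (3, 4)]:
        cb.on_predict_step_end(out)
    cb.on_predict_end()
    with open(cb.path, newline="") as f:
        lines = f.read().splitlines()
```

Closing the file in on_exception as well as in on_predict_end means a failed run still leaves a readable file containing every row written before the error.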
