BaseCSVWriter
class torchtnt.framework.callbacks.BaseCSVWriter(header_row: List[str], dir_path: str, delimiter: str = '\t', filename: str = 'predictions.csv')

A callback to write prediction outputs to a CSV file.
This callback provides an interface that simplifies writing outputs to a CSV file during prediction. It must be extended with an implementation of get_step_output_rows that returns the desired outputs as rows of the CSV file. By default, the outputs of each step across all processes are written to the same CSV file. Each row is a list of strings whose entries should match the column names defined in header_row.

Parameters:
- header_row – column names of the CSV file
- dir_path – path of the directory in which to save the CSV file
- delimiter – character that separates the columns within a row. Default is a tab ('\t')
- filename – name of the CSV file. Default is “predictions.csv”
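To make the subclassing contract concrete, here is a minimal stand-in written with Python's standard csv module. It is not torchtnt's implementation: the class name MiniCSVWriter, the single-argument get_step_output_rows(step_output) signature, the write_steps helper, and the row-length check are all hypothetical, chosen only to mirror the documented constructor parameters and the rule that each row is a list of strings matching header_row.

```python
import csv
import os
import tempfile
from abc import ABC, abstractmethod
from typing import Any, List, Union


class MiniCSVWriter(ABC):
    """Hypothetical stand-in mirroring BaseCSVWriter's documented constructor.

    Not torchtnt's implementation: it only illustrates the contract that a
    subclass turns each step's output into rows matching ``header_row``.
    """

    def __init__(
        self,
        header_row: List[str],
        dir_path: str,
        delimiter: str = "\t",
        filename: str = "predictions.csv",
    ) -> None:
        self.header_row = header_row
        self.path = os.path.join(dir_path, filename)
        self.delimiter = delimiter

    @abstractmethod
    def get_step_output_rows(self, step_output: Any) -> Union[List[str], List[List[str]]]:
        """Return one row, or a list of rows, for a single predict step."""

    def write_steps(self, step_outputs: List[Any]) -> None:
        # Hypothetical helper: write the header, then the rows of every step.
        with open(self.path, "w", newline="") as f:
            writer = csv.writer(f, delimiter=self.delimiter)
            writer.writerow(self.header_row)
            for step_output in step_outputs:
                rows = self.get_step_output_rows(step_output)
                # Wrap a single row (a list of str) so both shapes are handled alike.
                if rows and isinstance(rows[0], str):
                    rows = [rows]
                for row in rows:
                    # Assumed sanity check: every row must have one value per column.
                    if len(row) != len(self.header_row):
                        raise ValueError(f"expected {len(self.header_row)} columns, got {len(row)}")
                    writer.writerow(row)


# Example subclass: each step yields (ids, scores) for a small batch.
class IdScoreWriter(MiniCSVWriter):
    def get_step_output_rows(self, step_output: Any) -> List[List[str]]:
        ids, scores = step_output
        return [[str(i), f"{s:.2f}"] for i, s in zip(ids, scores)]


out_dir = tempfile.mkdtemp()
writer = IdScoreWriter(header_row=["id", "score"], dir_path=out_dir)
writer.write_steps([([0, 1], [0.9, 0.25]), ([2], [0.5])])
```

Letting get_step_output_rows return either one row or a list of rows is a convenience assumed here so that a step producing a whole batch can emit several rows at once; check the real method's signature and return type in the torchtnt source before relying on it.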
- on_exception(state: State, unit: Union[TrainUnit[TTrainData], EvalUnit[TEvalData], PredictUnit[TPredictData]], exc: BaseException) -> None
  Hook called when an exception occurs.
- on_predict_end(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called after prediction ends.
- on_predict_start(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called before prediction starts.
- on_predict_step_end(state: State, unit: PredictUnit[TPredictData]) -> None
  Hook called after a predict step ends.
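The four hooks above suggest a natural division of work for a CSV writer. The sketch below assumes, without confirming it from torchtnt's source, that on_predict_start opens the file and writes the header, on_predict_step_end appends the current step's rows, and on_predict_end or on_exception closes the file. The RecordingCSVWriter class, the run_predict driver loop, and the state.step_output attribute are hypothetical stand-ins for the framework's State, PredictUnit, and predict loop.

```python
import csv
import os
import tempfile
from types import SimpleNamespace
from typing import Any, Callable, List


class RecordingCSVWriter:
    """Hypothetical stand-in, not torchtnt's code: shows one plausible way
    the four documented hooks could drive CSV output."""

    def __init__(self, header_row: List[str], dir_path: str,
                 delimiter: str = "\t", filename: str = "predictions.csv") -> None:
        self.header_row = header_row
        self.path = os.path.join(dir_path, filename)
        self.delimiter = delimiter
        self.file = None

    def on_predict_start(self, state: Any, unit: Any) -> None:
        # Open the file once and write the header row before any step runs.
        self.file = open(self.path, "w", newline="")
        self.writer = csv.writer(self.file, delimiter=self.delimiter)
        self.writer.writerow(self.header_row)

    def on_predict_step_end(self, state: Any, unit: Any) -> None:
        # Assumption: the step's rows are exposed as state.step_output.
        self.writer.writerows(state.step_output)

    def on_predict_end(self, state: Any, unit: Any) -> None:
        self.file.close()

    def on_exception(self, state: Any, unit: Any, exc: BaseException) -> None:
        # Close the file so rows from completed steps are flushed to disk.
        if self.file is not None and not self.file.closed:
            self.file.close()


def run_predict(callback: RecordingCSVWriter,
                steps: List[Callable[[], List[List[str]]]]) -> None:
    """Hypothetical driver loop standing in for the framework's predict loop."""
    state, unit = SimpleNamespace(step_output=None), None
    callback.on_predict_start(state, unit)
    try:
        for step in steps:
            state.step_output = step()
            callback.on_predict_step_end(state, unit)
    except Exception as exc:
        callback.on_exception(state, unit, exc)
        raise
    callback.on_predict_end(state, unit)


def failing_step() -> List[List[str]]:
    raise RuntimeError("simulated failure")


out_dir = tempfile.mkdtemp()
cb = RecordingCSVWriter(["id", "label"], out_dir)
try:
    run_predict(cb, [lambda: [["0", "cat"]], lambda: [["1", "dog"]], failing_step])
except RuntimeError:
    pass  # the failure propagates after on_exception has closed the file
```

Closing the file in on_exception means rows from steps that finished before a failure still reach disk, which is usually the desired behavior for a prediction dump.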