sciwing.infer.classification

BaseClassificationInference

class sciwing.infer.classification.BaseClassificationInference.BaseClassificationInference(model: torch.nn.Module, model_filepath: str, datasets_manager: sciwing.data.datasets_manager.DatasetsManager, device: Union[str, torch.device, None] = <default not rendered>)

Bases: object

Abstract base class for classification inference. BaseClassificationInference provides a skeleton for concrete classes that perform inference for a text classification task.

__init__(model: torch.nn.Module, model_filepath: str, datasets_manager: sciwing.data.datasets_manager.DatasetsManager, device: Union[str, torch.device, None] = <default not rendered>)
Parameters:
  • model (nn.Module) – A PyTorch module
  • model_filepath (str) – The path where the parameters of the best model are stored. This is usually best_model.pt inside an experiment directory
  • datasets_manager (DatasetsManager) – Any dataset that conforms to the PyTorch Dataset specification
  • device (Optional[Union[str, torch.device]]) – Either a string such as cpu or cuda:0, or a torch.device object
get_misclassified_sentences(true_label_idx: int, pred_label_idx: int) → List[str]
get_true_label_indices_names(labels: List[sciwing.data.label.Label]) -> (typing.List[int], typing.List[str])

Given a list of labels, returns the indices and the names of the labels.

Parameters: labels (List[Label]) – The labels whose indices and names are required
Returns: A list of integers that represent the true classes, and a list of strings that represent the true classes
Return type: (List[int], List[str])
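
The mapping described above can be sketched in plain Python. This is only an illustration under stated assumptions: SimpleLabel and label_to_idx are hypothetical stand-ins introduced here, not sciwing objects, and the real method may obtain the vocabulary differently.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class SimpleLabel:
    """Hypothetical stand-in for sciwing.data.label.Label."""
    name: str


def true_label_indices_names(
    labels: List[SimpleLabel], label_to_idx: Dict[str, int]
) -> Tuple[List[int], List[str]]:
    # Collect the class name of every label, then look up its index.
    names = [label.name for label in labels]
    indices = [label_to_idx[name] for name in names]
    return indices, names


# Assumed vocabulary, for illustration only.
label_to_idx = {"background": 0, "method": 1, "result": 2}
indices, names = true_label_indices_names(
    [SimpleLabel("method"), SimpleLabel("background")], label_to_idx
)
```
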
infer_batch(lines: List[sciwing.data.line.Line])
load_model()

Loads the best_model from the model_filepath.

model_forward_on_lines(lines: List[sciwing.data.line.Line])

Performs the model forward pass on the given lines.

Parameters: lines (List[Line]) – The lines on which to run the forward pass
model_output_dict_to_prediction_indices_names(model_output_dict: Dict[str, Any]) -> (typing.List[int], typing.List[str])

Given a model_output_dict, returns the predicted class indices and names.

Parameters: model_output_dict (Dict[str, Any]) – Output dictionary from a model
Returns: A list of integers that represent the predicted classes, and a list of strings that represent the predicted classes
Return type: (List[int], List[str])
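
A minimal sketch of turning a model output dictionary into predicted indices and names. It assumes the dictionary holds per-class probabilities under the key normalized_probs (the default value of normalized_probs_namespace in ClassificationInference) and uses plain lists instead of tensors. The function name and idx_to_label mapping are hypothetical.

```python
from typing import Any, Dict, List, Tuple


def predictions_from_output(
    model_output_dict: Dict[str, Any],
    idx_to_label: Dict[int, str],
    probs_key: str = "normalized_probs",
) -> Tuple[List[int], List[str]]:
    # One row of class probabilities per input line (assumed layout).
    probs = model_output_dict[probs_key]
    # Arg-max over each row gives the predicted class index.
    indices = [max(range(len(row)), key=row.__getitem__) for row in probs]
    names = [idx_to_label[i] for i in indices]
    return indices, names


idx_to_label = {0: "background", 1: "method", 2: "result"}
output = {"normalized_probs": [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]}
pred_indices, pred_names = predictions_from_output(output, idx_to_label)
```
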
on_user_input(line: sciwing.data.line.Line)
print_confusion_matrix()
report_metrics()

Reports the metrics for the dataset.

run_inference() → Dict[str, Any]

Runs inference on the test dataset.

This method should run the model through the test dataset. It should perform inference and collect the metrics and data that are necessary for further use.

Returns: A dictionary of the metrics and data collected during inference
Return type: Dict[str, Any]
run_test()

Classification Inference

class sciwing.infer.classification.classification_inference.ClassificationInference(model: torch.nn.Module, model_filepath: str, datasets_manager: sciwing.data.datasets_manager.DatasetsManager, tokens_namespace: str = 'tokens', normalized_probs_namespace: str = 'normalized_probs', device: str = 'cpu')

Bases: sciwing.infer.classification.BaseClassificationInference.BaseClassificationInference

The sciwing engine runs the test lines through the classifier and returns the predictions/probabilities for the different classes. At a later point in time this class should be able to take any collection of lines (for example, from a file) and produce the output.

This class also helps in performing various interactions with the results on the test dataset. Some features are: 1) showing the confusion matrix; 2) investigating a particular example in the test dataset; 3) retrieving instances that were classified as one class when their true label is another (for example, classified as 2 when the true label is 1).

All it needs is the configuration file stored under every experiment, with a vocab already stored in the experiment folder.

generate_report_for_paper()

Generates only the F-score, for use when reporting results in print.
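
For reference, an F-score can be computed from true and predicted class indices as sketched below. The source does not say which averaging the method reports; macro averaging is used here purely as an assumption, and both function names are hypothetical.

```python
from typing import List


def f1_per_class(
    true_indices: List[int], pred_indices: List[int], num_classes: int
) -> List[float]:
    scores = []
    pairs = list(zip(true_indices, pred_indices))
    for c in range(num_classes):
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs.
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


def macro_f1(true_indices: List[int], pred_indices: List[int], num_classes: int) -> float:
    # Unweighted mean of the per-class F1 scores.
    scores = f1_per_class(true_indices, pred_indices, num_classes)
    return sum(scores) / len(scores)


per_class = f1_per_class([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3)
macro = macro_f1([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3)
```
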

get_misclassified_sentences(true_label_idx: int, pred_label_idx: int)

Returns the sentences whose true label is true_label_idx but which were classified as pred_label_idx.

Parameters:
  • true_label_idx (int) – The label index of the true class name
  • pred_label_idx (int) – The label index of the predicted class name
Returns: A list of strings where the true class is classified as the predicted class
Return type: List[str]
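
The selection described above amounts to a filter over parallel lists. A minimal sketch, assuming the sentences, true indices, and predicted indices have already been collected in matching order (the function name and its three list arguments are hypothetical):

```python
from typing import List


def misclassified_sentences(
    sentences: List[str],
    true_indices: List[int],
    pred_indices: List[int],
    true_label_idx: int,
    pred_label_idx: int,
) -> List[str]:
    # Keep sentences whose true class and predicted class match the query pair.
    return [
        sentence
        for sentence, t, p in zip(sentences, true_indices, pred_indices)
        if t == true_label_idx and p == pred_label_idx
    ]


sentences = ["s0", "s1", "s2", "s3"]
selected = misclassified_sentences(
    sentences, [1, 1, 0, 1], [2, 1, 2, 2], true_label_idx=1, pred_label_idx=2
)
```
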

get_true_label_indices_names(labels: List[sciwing.data.label.Label]) -> (typing.List[int], typing.List[str])

Given a list of labels, returns the indices and the names of the labels.

Parameters: labels (List[Label]) – The labels whose indices and names are required
Returns: A list of integers that represent the true classes, and a list of strings that represent the true classes
Return type: (List[int], List[str])
infer_batch(lines: List[str]) → List[str]

Runs inference on a batch of lines. This method is useful for applications, such as APIs that serve predictions over the web or terminal applications that read lines from files and classify them.

Parameters: lines (List[str]) – List of text spans to be inferred
Returns: The class names for all the sentences in the input
Return type: List[str]
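
A sketch of how an application might call infer_batch on text read from a file. KeywordClassifier is a hypothetical stand-in that only mimics the infer_batch(lines: List[str]) -> List[str] signature documented above; classify_text is likewise an illustrative helper, not part of sciwing.

```python
from typing import List, Protocol, Tuple


class BatchClassifier(Protocol):
    """Anything exposing the documented infer_batch signature."""

    def infer_batch(self, lines: List[str]) -> List[str]: ...


class KeywordClassifier:
    """Hypothetical stand-in for a trained inference client."""

    def infer_batch(self, lines: List[str]) -> List[str]:
        return ["method" if "propose" in line.lower() else "background" for line in lines]


def classify_text(text: str, classifier: BatchClassifier) -> List[Tuple[str, str]]:
    # Drop blank lines, classify the rest in one batch, and pair each line with its class.
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return list(zip(lines, classifier.infer_batch(lines)))


pairs = classify_text("We propose a model.\n\nPrior work exists.\n", KeywordClassifier())
```
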
model_forward_on_lines(lines: List[sciwing.data.line.Line])

Performs the model forward pass on the given lines.

Parameters: lines (List[Line]) – The lines on which to run the forward pass
model_output_dict_to_prediction_indices_names(model_output_dict: Dict[str, Any]) -> (typing.List[int], typing.List[str])

Given a model_output_dict, returns the predicted class indices and names.

Parameters: model_output_dict (Dict[str, Any]) – Output dictionary from a model
Returns: A list of integers that represent the predicted classes, and a list of strings that represent the predicted classes
Return type: (List[int], List[str])
on_user_input(line: str) → str

Runs inference when the user inputs a single sentence, either on the terminal or in some other application.

Parameters: line (str) – The line entered by the user
Returns: The class label that is inferred for the user input
Return type: str
print_confusion_matrix() → None

Prints the confusion matrix for the test dataset
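
A confusion matrix of the kind this method prints can be built from true and predicted class indices as sketched below. The layout (rows are true classes, columns are predicted classes) and both helper names are assumptions made for illustration.

```python
from typing import List


def confusion_matrix(
    true_indices: List[int], pred_indices: List[int], num_classes: int
) -> List[List[int]]:
    # matrix[t][p] counts instances of true class t predicted as class p.
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_indices, pred_indices):
        matrix[t][p] += 1
    return matrix


def format_matrix(matrix: List[List[int]], names: List[str]) -> str:
    # Right-align every cell to a common width so columns line up.
    width = max(len(name) for name in names) + 2
    header = " " * width + "".join(name.rjust(width) for name in names)
    rows = [
        name.rjust(width) + "".join(str(cell).rjust(width) for cell in row)
        for name, row in zip(names, matrix)
    ]
    return "\n".join([header] + rows)


matrix = confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3)
table = format_matrix(matrix, ["background", "method", "result"])
print(table)
```
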

report_metrics()

Reports the metrics for the dataset.

run_inference() → Dict[str, Any]

Runs inference on the test dataset.

This method should run the model through the test dataset. It should perform inference and collect the metrics and data that are necessary for further use.

Returns: A dictionary of the metrics and data collected during inference
Return type: Dict[str, Any]
run_test()

Runs inference and reports test metrics
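
The relationship between run_inference, report_metrics, and run_test can be sketched as a template method. InferenceSkeleton and ToyInference are hypothetical classes written for this illustration; storing the result on self.output and returning an accuracy from report_metrics are assumptions, not documented sciwing behaviour.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class InferenceSkeleton(ABC):
    """Hypothetical skeleton mirroring the method names documented above."""

    def __init__(self) -> None:
        self.output: Optional[Dict[str, Any]] = None

    @abstractmethod
    def run_inference(self) -> Dict[str, Any]:
        """Run the model over the test data and collect results."""

    @abstractmethod
    def report_metrics(self) -> float:
        """Summarise the collected results."""

    def run_test(self) -> float:
        # Run inference first, then report metrics on what was collected.
        self.output = self.run_inference()
        return self.report_metrics()


class ToyInference(InferenceSkeleton):
    def run_inference(self) -> Dict[str, Any]:
        # Fixed toy data in place of a real model and dataset.
        return {"true": [0, 1, 1], "pred": [0, 1, 0]}

    def report_metrics(self) -> float:
        assert self.output is not None
        pairs = list(zip(self.output["true"], self.output["pred"]))
        return sum(t == p for t, p in pairs) / len(pairs)


accuracy = ToyInference().run_test()
```
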