Skip to content

Commit

Permalink
Allow Ensembler class work with models that returns list of tensors a…
Browse files Browse the repository at this point in the history
…s output
  • Loading branch information
BloodAxe committed Nov 3, 2023
1 parent b650f35 commit 6556bc8
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions pytorch_toolbelt/inference/ensembling.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def forward(self, *input, **kwargs): # skipcq: PYL-W0221
keys = self.outputs
elif isinstance(outputs[0], dict):
keys = outputs[0].keys()
output_is_dict = True
elif isinstance(outputs[0], (list, tuple)):
keys = list(range(len(outputs[0])))
output_is_dict = False
elif torch.is_tensor(outputs[0]):
keys = None
else:
Expand All @@ -104,12 +108,15 @@ def forward(self, *input, **kwargs): # skipcq: PYL-W0221
predictions = _deaugment_averaging(predictions, self.reduction)
averaged_output = predictions
else:
averaged_output = {}
averaged_output = {} if output_is_dict else []
for key in keys:
predictions = [output[key] for output in outputs]
predictions = torch.stack(predictions)
predictions = _deaugment_averaging(predictions, self.reduction)
averaged_output[key] = predictions
if output_is_dict:
averaged_output[key] = predictions
else:
averaged_output.append(predictions)

return averaged_output

Expand Down

0 comments on commit 6556bc8

Please sign in to comment.