diff options
Diffstat (limited to 'bsie/extractor/text/summary.py')
-rw-r--r-- | bsie/extractor/text/summary.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/bsie/extractor/text/summary.py b/bsie/extractor/text/summary.py index cc8d90d..2c9efef 100644 --- a/bsie/extractor/text/summary.py +++ b/bsie/extractor/text/summary.py @@ -8,11 +8,11 @@ import transformers # bsie imports from bsie.extractor import base from bsie.matcher import nodes -from bsie.utils import bsfs, errors, ns +from bsie.utils import bsfs, ns # exports __all__: typing.Sequence[str] = ( - 'Language', + 'Summary', ) @@ -51,8 +51,8 @@ class Summary(base.Extractor): length_penalty=length_penalty, ) self._summarizer = transformers.pipeline( - "summarization", - model="joemgu/mlong-t5-large-sumstew", + 'summarization', + model='joemgu/mlong-t5-large-sumstew', ) def extract( @@ -60,17 +60,17 @@ class Summary(base.Extractor): subject: nodes.Entity, content: typing.Sequence[str], principals: typing.Iterable[bsfs.schema.Predicate], - ) -> typing.Iterator[typing.Tuple[nodes.Entity, bsfs.schema.Predicate, str]]: + ) -> typing.Iterator[typing.Tuple[nodes.Node, bsfs.schema.Predicate, str]]: # check predicates if self._predicate not in principals: return # preprocess - text = '\n'.join(content) - # generate summary - summaries = self._summarizer(text, **self._generator_kwargs) - if len(summaries) == 0: + text = '\n'.join(content).strip() + if len(text) == 0: return - # fetch summary, ignore title + # fetch summary + summaries = self._summarizer(text, **self._generator_kwargs) + assert len(summaries) == 1 prefix = 'Summary: ' title_and_summary = summaries[0]['summary_text'] summary = title_and_summary[title_and_summary.find(prefix) + len(prefix):] |