aboutsummaryrefslogtreecommitdiffstats
path: root/bsie/extractor/text/summary.py
diff options
context:
space:
mode:
Diffstat (limited to 'bsie/extractor/text/summary.py')
-rw-r--r--bsie/extractor/text/summary.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/bsie/extractor/text/summary.py b/bsie/extractor/text/summary.py
index cc8d90d..2c9efef 100644
--- a/bsie/extractor/text/summary.py
+++ b/bsie/extractor/text/summary.py
@@ -8,11 +8,11 @@ import transformers
# bsie imports
from bsie.extractor import base
from bsie.matcher import nodes
-from bsie.utils import bsfs, errors, ns
+from bsie.utils import bsfs, ns
# exports
__all__: typing.Sequence[str] = (
- 'Language',
+ 'Summary',
)
@@ -51,8 +51,8 @@ class Summary(base.Extractor):
length_penalty=length_penalty,
)
self._summarizer = transformers.pipeline(
- "summarization",
- model="joemgu/mlong-t5-large-sumstew",
+ 'summarization',
+ model='joemgu/mlong-t5-large-sumstew',
)
def extract(
@@ -60,17 +60,17 @@ class Summary(base.Extractor):
subject: nodes.Entity,
content: typing.Sequence[str],
principals: typing.Iterable[bsfs.schema.Predicate],
- ) -> typing.Iterator[typing.Tuple[nodes.Entity, bsfs.schema.Predicate, str]]:
+ ) -> typing.Iterator[typing.Tuple[nodes.Node, bsfs.schema.Predicate, str]]:
# check predicates
if self._predicate not in principals:
return
# preprocess
- text = '\n'.join(content)
- # generate summary
- summaries = self._summarizer(text, **self._generator_kwargs)
- if len(summaries) == 0:
+ text = '\n'.join(content).strip()
+ if len(text) == 0:
return
- # fetch summary, ignore title
+ # fetch summary
+ summaries = self._summarizer(text, **self._generator_kwargs)
+ assert len(summaries) == 1
prefix = 'Summary: '
title_and_summary = summaries[0]['summary_text']
summary = title_and_summary[title_and_summary.find(prefix) + len(prefix):]