aboutsummaryrefslogtreecommitdiffstats
path: root/bsie/extractor/text/summary.py
blob: cc8d90d69fc8c4b6871cbaf44c8d7ffe18bcea74 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79

# standard imports
import typing

# external imports
import transformers

# bsie imports
from bsie.extractor import base
from bsie.matcher import nodes
from bsie.utils import bsfs, errors, ns

# exports
__all__: typing.Sequence[str] = (
    'Language',
    )


## code ##

class Summary(base.Extractor):
    """Extract a text summary.

    Uses the following summarization model:
        https://huggingface.co/Joemgu/mlong-t5-large-sumstew

    """

    # reader that produces the content passed to extract()
    CONTENT_READER = 'bsie.reader.document.Document'

    # predicate (bse:summary) under which the summary triple is emitted
    _predicate: bsfs.schema.Predicate

    # generation options forwarded to the summarization pipeline call
    _generator_kwargs: typing.Dict[str, typing.Any]

    # huggingface summarization pipeline instance
    _summarizer: transformers.pipelines.text2text_generation.SummarizationPipeline

    def __init__(
            self,
            max_length: int = 1024, # summary length in tokens
            num_beams: int = 4, # higher = better, but uses more memory
            length_penalty: float = 1.0, # higher = longer summaries
            ):
        super().__init__(bsfs.schema.from_string(base.SCHEMA_PREAMBLE + '''
            bse:summary rdfs:subClassOf bsfs:Predicate ;
                rdfs:domain bsn:Entity ;
                rdfs:range xsd:string ;
                bsfs:unique "true"^^xsd:boolean .
            '''))
        self._predicate = self.schema.predicate(ns.bse.summary)
        self._generator_kwargs = dict(
            max_length=max_length,
            num_beams=num_beams,
            length_penalty=length_penalty,
            )
        self._summarizer = transformers.pipeline(
            "summarization",
            model="joemgu/mlong-t5-large-sumstew",
            )

    def extract(
            self,
            subject: nodes.Entity,
            content: typing.Sequence[str],
            principals: typing.Iterable[bsfs.schema.Predicate],
            ) -> typing.Iterator[typing.Tuple[nodes.Entity, bsfs.schema.Predicate, str]]:
        """Yield at most one (subject, bse:summary, text) triple for *content*.

        Yields nothing if bse:summary is not among *principals*, if the
        content is empty, or if the model returns no summary.
        """
        # check predicates
        if self._predicate not in principals:
            return
        # preprocess
        text = '\n'.join(content)
        # nothing to summarize; avoid invoking the model on empty input
        if not text.strip():
            return
        # generate summary
        summaries = self._summarizer(text, **self._generator_kwargs)
        if len(summaries) == 0:
            return
        # fetch summary, ignore title: the model output is expected to
        # contain a 'Summary: ' marker separating title from summary.
        title_and_summary = summaries[0]['summary_text']
        prefix = 'Summary: '
        idx = title_and_summary.find(prefix)
        # Fix: str.find returns -1 when the prefix is absent; the previous
        # unconditional slice (-1 + len(prefix)) silently dropped the first
        # eight characters of such outputs. Fall back to the full text.
        summary = title_and_summary[idx + len(prefix):] if idx >= 0 else title_and_summary
        yield subject, self._predicate, summary

## EOF ##