"""This module contains the Reader class."""
from .builtin_datasets import BUILTIN_DATASETS
class Reader:
    """The Reader class is used to parse a file containing ratings.

    Such a file is assumed to specify only one rating per line, and each line
    needs to respect the following structure: ::

        user ; item ; rating ; [timestamp]

    where the order of the fields and the separator (here ';') may be
    arbitrarily defined (see below). Brackets indicate that the timestamp
    field is optional.

    For each built-in dataset, Surprise also provides predefined readers which
    are useful if you want to use a custom dataset that has the same format as
    a built-in one (see the ``name`` parameter).

    Args:
        name(:obj:`string`, optional): If specified, the Reader is configured
            with the predefined parameters of that built-in dataset and every
            other parameter is ignored. Accepted values are the keys of
            ``BUILTIN_DATASETS`` (e.g. 'ml-100k', 'ml-1m', 'jester').
            Default is ``None``.
        line_format(:obj:`string`): The field names, in the order in which
            they are encountered on a line. Please note that ``line_format`` is
            always space-separated (use the ``sep`` parameter). It must contain
            ``user``, ``item`` and ``rating`` and may also contain
            ``timestamp``. Default is ``'user item rating'``.
        sep(char): the separator between fields. Example : ``';'``. If
            ``None`` (the default), lines are split on runs of whitespace.
        rating_scale(:obj:`tuple`, optional): The rating scale used for every
            rating, as a ``(lower_bound, upper_bound)`` pair. Default is
            ``(1, 5)``.
        skip_lines(:obj:`int`, optional): Number of lines to skip at the
            beginning of the file. Default is ``0``.

    Raises:
        ValueError: if ``name`` is not a known built-in dataset, if
            ``line_format`` contains an unknown field or lacks a required one,
            or if ``rating_scale`` does not hold exactly two values.
    """

    def __init__(
        self,
        name=None,
        line_format="user item rating",
        sep=None,
        rating_scale=(1, 5),
        skip_lines=0,
    ):
        if name:
            # Keep only the lookup inside the try block: a KeyError raised
            # while applying the preset parameters must not be misreported as
            # an unknown dataset name.
            try:
                reader_params = BUILTIN_DATASETS[name].reader_params
            except KeyError:
                accepted = ", ".join(BUILTIN_DATASETS.keys())
                raise ValueError(
                    f"unknown reader {name}. Accepted values are {accepted}."
                ) from None
            # Re-run the initialisation with the dataset's predefined params.
            self.__init__(**reader_params)
            return

        self.sep = sep
        self.skip_lines = skip_lines
        self.rating_scale = rating_scale
        # The bounds are not used here; unpacking them validates that
        # rating_scale holds exactly two values (ValueError otherwise).
        _lower_bound, _higher_bound = rating_scale

        splitted_format = line_format.split()

        entities = ["user", "item", "rating"]
        self.with_timestamp = "timestamp" in splitted_format
        if self.with_timestamp:
            entities.append("timestamp")

        # Every field must be a known entity and every required entity must
        # be present; otherwise list.index below would fail with an obscure
        # "'x' is not in list" message.
        unknown_field = any(field not in entities for field in splitted_format)
        missing_field = any(entity not in splitted_format for entity in entities)
        if unknown_field or missing_field:
            raise ValueError("line_format parameter is incorrect.")

        # Position on a line of user, item, rating (and timestamp if present).
        self.indexes = [splitted_format.index(entity) for entity in entities]

    def parse_line(self, line):
        """Parse a line.

        Fields are stripped of surrounding whitespace and the rating is
        converted to ``float``; no shifting or rescaling is applied.

        Args:
            line(str): The line to parse

        Returns:
            tuple: User id, item id, rating and timestamp. The ids and the
            timestamp are strings, the rating is a ``float``. The timestamp is
            set to ``None`` if it does not exist.

        Raises:
            ValueError: if the line has fewer fields than ``line_format``
                requires, or if the rating cannot be converted to ``float``.
        """
        fields = line.split(self.sep)
        try:
            values = [fields[i].strip() for i in self.indexes]
        except IndexError:
            raise ValueError(
                "Impossible to parse line. Check the line_format"
                " and sep parameters."
            ) from None

        if self.with_timestamp:
            uid, iid, r, timestamp = values
        else:
            uid, iid, r = values
            timestamp = None

        return uid, iid, float(r), timestamp