Add data format #872
Closed

Changes from all commits (20 commits):
a4044bb  add cifar (beckett1124)
22a8d06  update cifar (beckett1124)
7192a6b  update cifar (beckett1124)
0913bbc  add mnist and amazon (beckett1124)
a4ed798  update amazon and mnist (beckett1124)
ce0f5b0  add other data (beckett1124)
1373977  update code
ee9b1c6  update
c53599f  update
6173153  Refine amazon_product_reviews.py (reyoung)
a97f44c  Merge pull request #1 from reyoung/feature/data_api (beckett1124)
7972f74  Add md5 checks. (reyoung)
294f298  Add preprocess method for amazon reviews (reyoung)
20c96b7  Done with amazon product reviews. (reyoung)
60b6ef5  Merge pull request #2 from reyoung/feature/data_api (beckett1124)
0c76c64  add new file path
9a803d0  update file path
c6a2600  add Data md
3fa7f21  new
bee88c9  updata amazon & cifar & mnist
@@ -0,0 +1,44 @@
### Datasets

Paddle currently ships many demos, and each demo has to download its data from the original website and run a complicated preprocessing pipeline, which costs a lot of time. So that everyone experimenting with Paddle can access the preprocessed data directly, we provide a Python library: importing a data source (e.g. paddle.data.amazon_product_reviews) cuts down the time needed to obtain training data. If you prefer to process the raw data yourself, we still provide a raw-data interface to meet that need.

## API Design

A dataset is imported via import paddle.data.amazon_product_reviews, and you can obtain the dataset you need directly through load_data(category=None, directory=None). Since an Amazon-style source holds more than one kind of data, the category argument selects the data source you want; if you do not specify one, it defaults to "Electronics". The directory argument sets the download path; if you do not specify one, it defaults to "~/paddle_data/amazon". The data object returned by load_data() wraps our preprocessed numpy-format data; call data.train_data() to get the training data or data.test_data() to get the test data. To print the samples of either split, this is all you need:

```python
for each_train_data in data.train_data():
    print each_train_data
```

A concrete demo looks like this:
```python
import paddle.data.amazon_product_reviews as raw

raw.data(batch_size=10)
```
You can also print the samples of each data set. For the test set or the training set:
```python
import paddle.data.amazon_product_reviews as raw

raw.test_data(batch_size=10)
raw.train_data(batch_size=10)
```

The printed samples are all preprocessed data in numpy format:
```python
(array([1370072, 884914, 1658622, 1562803, 1579, 369164, 1129091,
        1073545, 1410234, 857854, 672274, 884920, 1078270, 1410234,
        777903, 1352600, 497103, 132906, 239745, 65294, 1502324,
        1165610, 204273, 1610806, 942942, 709056, 452087, 118093,
        1410234], dtype=int32), array([ True], dtype=bool))
(array([ 777903, 713632, 452087, 1647686, 877980, 294749, 1575945,
         662947, 1431519, 462950, 452087, 902916, 479242, 294749,
        1278816, 672274, 1579, 394865, 1129091, 1352600, 294749,
        1073545], dtype=int32), array([ True], dtype=bool))
```
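For clarity, a minimal end-to-end sketch of the load_data() interface described above (note the demos call raw.data/raw.test_data/raw.train_data directly; the data object here is the proposed return value, so treat the exact names as assumptions):

```python
# Sketch of the proposed interface, assuming load_data() returns an object
# exposing train_data() and test_data() as described above.
import paddle.data.amazon_product_reviews as amazon

data = amazon.load_data(category='Electronics',            # default category
                        directory='~/paddle_data/amazon')  # default path

for each_train_data in data.train_data():
    print each_train_data   # (int32 word-id array, bool label array)
```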
@@ -0,0 +1,5 @@
""" | ||
The :mod:`paddle.datasets` module includes utilities to load datasets, | ||
including methods to load and fetch popular reference datasets. It also | ||
features some artificial data generators. | ||
""" |
@@ -0,0 +1,337 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A utility for fetching and reading the Amazon product review data set.

http://jmcauley.ucsd.edu/data/amazon/
"""

import os
from http_download import download
from logger import logger
import gzip
import json
import hashlib
import nltk
import collections
import h5py
import numpy
import random

BASE_URL = 'http://snap.stanford.edu/data/' \
           'amazon/productGraph/categoryFiles/reviews_%s_5.json.gz'

DATASET_LABEL = 'label'
DATASET_SENTENCE = 'sentence'

# A review whose overall score is >= 5 counts as a positive sample and one
# whose score is <= 2 as a negative sample; scores of 3 and 4 are skipped.
positive_threshold = 5
negative_threshold = 2

class Categories(object):
    Books = "Books"
    Electronics = "Electronics"
    MoviesAndTV = "Movies_and_TV"
    CDsAndVinyl = "CDs_and_Vinyl"
    ClothingShoesAndJewelry = "Clothing_Shoes_and_Jewelry"
    HomeAndKitchen = "Home_and_Kitchen"
    KindleStore = "Kindle_Store"
    SportsAndOutdoors = "Sports_and_Outdoors"
    CellPhonesAndAccessories = "Cell_Phones_and_Accessories"
    HealthAndPersonalCare = "Health_and_Personal_Care"
    ToysAndGames = "Toys_and_Games"
    VideoGames = "Video_Games"
    ToolsAndHomeImprovement = "Tools_and_Home_Improvement"
    Beauty = "Beauty"
    AppsForAndroid = "Apps_for_Android"
    OfficeProducts = "Office_Products"
    PetSupplies = "Pet_Supplies"
    Automotive = "Automotive"
    GroceryAndGourmetFood = "Grocery_and_Gourmet_Food"
    PatioLawnAndGarden = "Patio_Lawn_and_Garden"
    Baby = "Baby"
    DigitalMusic = "Digital_Music"
    MusicalInstruments = "Musical_Instruments"
    AmazonInstantVideo = "Amazon_Instant_Video"

    # MD5 checksums of the raw .json.gz files, used to skip re-downloading.
    __md5__ = dict()

    __md5__[AmazonInstantVideo] = '10812e43e99c345f63333d8ee10aef6a'
    __md5__[AppsForAndroid] = 'a7d1ae198b862eea6910fe45c842b0c6'
    __md5__[Automotive] = '757fdb1ab2c5e2fc0934047721082011'
    __md5__[Baby] = '7698a4179a1d8385e946ed9083490d22'
    __md5__[Beauty] = '5d2ccdcd86641efcfbae344317c10829'
    __md5__[Books] = 'bc1e2aa650fe51f978e9d3a7a4834bc6'
    __md5__[CDsAndVinyl] = '82bffdc956e76c32fa655b98eca9576b'
    __md5__[CellPhonesAndAccessories] = '903a19524d874970a2f0ae32a175a48f'
    __md5__[ClothingShoesAndJewelry] = 'b333fba48651ea2309288aeb51f8c6e4'
    __md5__[DigitalMusic] = '35e62f7a7475b53714f9b177d9dae3e7'
    __md5__[Electronics] = 'e4524af6c644cd044b1969bac7b62b2a'
    __md5__[GroceryAndGourmetFood] = 'd8720f98ea82c71fa5c1223f39b6e3d9'
    __md5__[HealthAndPersonalCare] = '352ea1f780a8629783220c7c9a9f7575'
    __md5__[HomeAndKitchen] = '90221797ccc4982f57e6a5652bea10fc'
    __md5__[KindleStore] = 'b608740c754287090925a1a186505353'
    __md5__[MoviesAndTV] = 'd3bb01cfcda2602c07bcdbf1c4222997'
    __md5__[MusicalInstruments] = '8035b6e3f9194844785b3f4cee296577'
    __md5__[OfficeProducts] = '1b7e64c707ecbdcdeca1efa09b716499'
    __md5__[PatioLawnAndGarden] = '4d2669abc5319d0f073ec3c3a85f18af'
    __md5__[PetSupplies] = '40568b32ca1536a4292e8410c5b9de12'
    __md5__[SportsAndOutdoors] = '1df6269552761c82aaec9667bf9a0b1d'
    __md5__[ToolsAndHomeImprovement] = '80bca79b84621d4848a88dcf37a1c34b'
    __md5__[ToysAndGames] = 'dbd07c142c47473c6ee22b535caee81f'
    __md5__[VideoGames] = '730612da2d6a93ed19f39a808b63993e'


__all__ = ['fetch', 'data', 'train_data', 'test_data']

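# Example: each category constant is substituted into BASE_URL to form the
# download location, e.g.
#   BASE_URL % Categories.Electronics
#   -> http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz
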
def calculate_md5(fn):
    """Compute the MD5 checksum of a file, reading it in 4 KB chunks."""
    h = hashlib.md5()
    with open(fn, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            h.update(chunk)
    return h.hexdigest()

def fetch(category=None, directory=None):
    """
    Download the raw review data for the given category into `directory`,
    skipping the download when a file with a matching MD5 checksum is
    already present, and return the local path for the training API.

    :param category: one of the Categories constants; defaults to Electronics.
    :param directory: download directory; defaults to ~/paddle_data/amazon.
    :return: path of the downloaded .json.gz file.
    """
    if category is None:
        category = Categories.Electronics

    if directory is None:
        directory = os.path.expanduser(
            os.path.join('~', 'paddle_data', 'amazon'))

    if not os.path.exists(directory):
        os.makedirs(directory)

    fn = os.path.join(directory, '%s.json.gz' % category)

    if os.path.exists(fn) and \
            calculate_md5(fn) == Categories.__md5__[category]:
        # already downloaded.
        return fn

    logger.info("Downloading amazon review dataset for %s category" % category)
    return download(BASE_URL % category, fn)

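# Usage sketch (assumed import path; the package wiring is not part of this
# file):
#   import paddle.data.amazon_product_reviews as amazon
#   fn = amazon.fetch(category=amazon.Categories.Books)
# The first call downloads ~/paddle_data/amazon/Books.json.gz; later calls
# return the cached file as long as its MD5 checksum still matches.
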
def preprocess(category=None, directory=None):
    """
    Download and preprocess the amazon reviews data set. Save the
    preprocessed result to an hdf5 file.

    The preprocessing uses nltk to tokenize the english sentences. Its
    output is slightly different from moses, but nltk is a pure python
    library, so it integrates well with Paddle.

    :return: hdf5 file name.
    """
    if category is None:
        category = Categories.Electronics

    if directory is None:
        directory = os.path.expanduser(
            os.path.join('~', 'paddle_data', 'amazon'))

    preprocess_fn = os.path.join(directory, '%s.hdf5' % category)
    raw_file_fn = fetch(category, directory)

    word_dict = collections.defaultdict(int)
    if not os.path.exists(preprocess_fn):  # not preprocessed yet.
        # First pass: count word frequencies across the whole corpus.
        with gzip.open(raw_file_fn, mode='r') as f:
            for sample_num, line in enumerate(f):
                txt = json.loads(line)['reviewText']
                try:  # automatically download nltk tokenizer data.
                    words = nltk.tokenize.word_tokenize(txt, 'english')
                except LookupError:
                    nltk.download('punkt')
                    words = nltk.tokenize.word_tokenize(txt, 'english')
                for each_word in words:
                    word_dict[each_word] += 1
            sample_num += 1  # enumerate is zero-based; turn it into a count.

        word_dict_sorted = []
        for each in word_dict:
            word_dict_sorted.append((each, word_dict[each]))

        # Sort words by ascending frequency; a word's rank becomes its id.
        word_dict_sorted.sort(key=lambda a: a[1])

        word_dict = dict()  # rebuilt below as a word -> id mapping.

        h5file = h5py.File(preprocess_fn, 'w')
        try:
            word_dict_h5 = h5file.create_dataset(
                'word_dict',
                shape=(len(word_dict_sorted), ),
                dtype=h5py.special_dtype(vlen=str))
            for i, each in enumerate(word_dict_sorted):
                word_dict_h5[i] = each[0]
                word_dict[each[0]] = i

            sentence = h5file.create_dataset(
                DATASET_SENTENCE,
                shape=(sample_num, ),
                dtype=h5py.special_dtype(vlen=numpy.int32))

            label = h5file.create_dataset(
                DATASET_LABEL, shape=(sample_num, 1), dtype=numpy.int8)

            # Second pass: store each review as an int32 word-id array
            # together with its 1-5 star rating.
            with gzip.open(raw_file_fn, mode='r') as f:
                for i, line in enumerate(f):
                    obj = json.loads(line)
                    txt = obj['reviewText']
                    score = numpy.int8(obj['overall'])
                    words = nltk.tokenize.word_tokenize(txt, 'english')
                    words = numpy.array(
                        [word_dict[w] for w in words], dtype=numpy.int32)
                    sentence[i] = words
                    label[i] = score
        finally:
            h5file.close()
    return preprocess_fn

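# Layout sketch of the hdf5 file written above (read-back example):
#   f = h5py.File(preprocess(), 'r')
#   f['word_dict'][-1]             # highest-frequency word (ascending sort)
#   numpy.array(f['sentence'][0])  # first review as int32 word ids
#   f['label'][0]                  # its star rating: shape (1,), int8
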
def data(batch_size, category=None, directory=None):
    """
    Print up to batch_size preprocessed samples, each as a tuple of an
    int32 word-id array and a boolean label (True for a positive review).

    :param batch_size: maximum number of samples to print.
    :param category: one of the Categories constants; defaults to Electronics.
    :param directory: data directory; defaults to ~/paddle_data/amazon.
    :return: None
    """
    if category is None:
        category = Categories.Electronics

    if directory is None:
        directory = os.path.expanduser(
            os.path.join('~', 'paddle_data', 'amazon'))

    fn = preprocess(category=category, directory=directory)
    datasets = h5py.File(fn, 'r')

    label = datasets[DATASET_LABEL]
    sentence = datasets[DATASET_SENTENCE]

    if label.shape[0] <= batch_size:
        lens = label.shape[0]
    else:
        lens = batch_size

    for index in range(lens):
        # Neutral reviews (score 3 or 4) fall between the thresholds and
        # are skipped; the printed label means "is a positive review".
        if label[index] >= positive_threshold:
            print (numpy.array(sentence[index]),
                   label[index] >= positive_threshold)
        elif label[index] <= negative_threshold:
            print (numpy.array(sentence[index]),
                   label[index] >= positive_threshold)

def test_data(batch_size, category=None, directory=None):
    """
    Print batch_size randomly ordered samples from the test split, which
    consists of the first batch_size positive and first batch_size negative
    samples; train_data() uses the remainder.

    :param batch_size: maximum number of samples to print.
    :param category: one of the Categories constants; defaults to Electronics.
    :param directory: data directory; defaults to ~/paddle_data/amazon.
    :return: None
    """
    if category is None:
        category = Categories.Electronics

    if directory is None:
        directory = os.path.expanduser(
            os.path.join('~', 'paddle_data', 'amazon'))

    fn = preprocess(category=category, directory=directory)
    datasets = h5py.File(fn, 'r')

    label = datasets[DATASET_LABEL]
    sentence = datasets[DATASET_SENTENCE]

    if label.shape[0] <= batch_size:
        lens = label.shape[0]
    else:
        lens = batch_size

    # Partition sample indices by the score thresholds.
    positive_idx = []
    negative_idx = []
    for i, lbl in enumerate(label):
        if lbl >= positive_threshold:
            positive_idx.append(i)
        elif lbl <= negative_threshold:
            negative_idx.append(i)

    __test_set__ = positive_idx[:lens] + negative_idx[:lens]

    random.shuffle(__test_set__)

    for index in __test_set__[:lens]:
        print (numpy.array(sentence[index]),
               label[index] >= positive_threshold)

def train_data(batch_size, category=None, directory=None):
    """
    Print batch_size randomly ordered samples from the training split,
    i.e. the positive and negative samples not reserved for the test
    split by test_data().

    :param batch_size: maximum number of samples to print.
    :param category: one of the Categories constants; defaults to Electronics.
    :param directory: data directory; defaults to ~/paddle_data/amazon.
    :return: None
    """
    if category is None:
        category = Categories.Electronics

    if directory is None:
        directory = os.path.expanduser(
            os.path.join('~', 'paddle_data', 'amazon'))

    fn = preprocess(category=category, directory=directory)
    datasets = h5py.File(fn, 'r')

    label = datasets[DATASET_LABEL]
    sentence = datasets[DATASET_SENTENCE]

    if label.shape[0] <= batch_size:
        lens = label.shape[0]
    else:
        lens = batch_size

    # Partition sample indices by the score thresholds.
    positive_idx = []
    negative_idx = []
    for i, lbl in enumerate(label):
        if lbl >= positive_threshold:
            positive_idx.append(i)
        elif lbl <= negative_threshold:
            negative_idx.append(i)

    __train_set__ = positive_idx[lens:] + negative_idx[lens:]

    random.shuffle(__train_set__)

    for index in __train_set__[:lens]:
        print (numpy.array(sentence[index]),
               label[index] >= positive_threshold)

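# Illustration of the split above with invented indices: if positive_idx is
# [0, 3, 4, 7], negative_idx is [1, 2, 6] and lens is 2, then
#   test set  = positive_idx[:2] + negative_idx[:2] -> [0, 3, 1, 2]
#   train set = positive_idx[2:] + negative_idx[2:] -> [4, 7, 6]
# Each printed sample is (word_id_array, is_positive_bool_array).
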
if __name__ == '__main__':
    data(10)
paddle.datasets ==> paddle.data ?