"""Test functionality of mldata fetching utilities.""" import os import shutil import tempfile import scipy as sp from sklearn import datasets from sklearn.datasets import mldata_filename, fetch_mldata from sklearn.utils.testing import assert_in from sklearn.utils.testing import assert_not_in from sklearn.utils.testing import mock_mldata_urlopen from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises from sklearn.utils.testing import with_setup from sklearn.utils.testing import assert_array_equal tmpdir = None def setup_tmpdata(): # create temporary dir global tmpdir tmpdir = tempfile.mkdtemp() os.makedirs(os.path.join(tmpdir, 'mldata')) def teardown_tmpdata(): # remove temporary dir if tmpdir is not None: shutil.rmtree(tmpdir) def test_mldata_filename(): cases = [('datasets-UCI iris', 'datasets-uci-iris'), ('news20.binary', 'news20binary'), ('book-crossing-ratings-1.0', 'book-crossing-ratings-10'), ('Nile Water Level', 'nile-water-level'), ('MNIST (original)', 'mnist-original')] for name, desired in cases: assert_equal(mldata_filename(name), desired) @with_setup(setup_tmpdata, teardown_tmpdata) def test_download(): """Test that fetch_mldata is able to download and cache a data set.""" _urlopen_ref = datasets.mldata.urlopen datasets.mldata.urlopen = mock_mldata_urlopen({ 'mock': { 'label': sp.ones((150,)), 'data': sp.ones((150, 4)), }, }) try: mock = fetch_mldata('mock', data_home=tmpdir) for n in ["COL_NAMES", "DESCR", "target", "data"]: assert_in(n, mock) assert_equal(mock.target.shape, (150,)) assert_equal(mock.data.shape, (150, 4)) assert_raises(datasets.mldata.HTTPError, fetch_mldata, 'not_existing_name') finally: datasets.mldata.urlopen = _urlopen_ref @with_setup(setup_tmpdata, teardown_tmpdata) def test_fetch_one_column(): _urlopen_ref = datasets.mldata.urlopen try: dataname = 'onecol' # create fake data set in cache x = sp.arange(6).reshape(2, 3) datasets.mldata.urlopen = mock_mldata_urlopen({dataname: {'x': x}}) dset = fetch_mldata(dataname, data_home=tmpdir) for n in ["COL_NAMES", "DESCR", "data"]: assert_in(n, dset) assert_not_in("target", dset) assert_equal(dset.data.shape, (2, 3)) assert_array_equal(dset.data, x) # transposing the data array dset = fetch_mldata(dataname, transpose_data=False, data_home=tmpdir) assert_equal(dset.data.shape, (3, 2)) finally: datasets.mldata.urlopen = _urlopen_ref @with_setup(setup_tmpdata, teardown_tmpdata) def test_fetch_multiple_column(): _urlopen_ref = datasets.mldata.urlopen try: # create fake data set in cache x = sp.arange(6).reshape(2, 3) y = sp.array([1, -1]) z = sp.arange(12).reshape(4, 3) # by default dataname = 'threecol-default' datasets.mldata.urlopen = mock_mldata_urlopen({ dataname: ( { 'label': y, 'data': x, 'z': z, }, ['z', 'data', 'label'], ), }) dset = fetch_mldata(dataname, data_home=tmpdir) for n in ["COL_NAMES", "DESCR", "target", "data", "z"]: assert_in(n, dset) assert_not_in("x", dset) assert_not_in("y", dset) assert_array_equal(dset.data, x) assert_array_equal(dset.target, y) assert_array_equal(dset.z, z.T) # by order dataname = 'threecol-order' datasets.mldata.urlopen = mock_mldata_urlopen({ dataname: ({'y': y, 'x': x, 'z': z}, ['y', 'x', 'z']), }) dset = fetch_mldata(dataname, data_home=tmpdir) for n in ["COL_NAMES", "DESCR", "target", "data", "z"]: assert_in(n, dset) assert_not_in("x", dset) assert_not_in("y", dset) assert_array_equal(dset.data, x) assert_array_equal(dset.target, y) assert_array_equal(dset.z, z.T) # by number dataname = 'threecol-number' datasets.mldata.urlopen = mock_mldata_urlopen({ dataname: ({'y': y, 'x': x, 'z': z}, ['z', 'x', 'y']), }) dset = fetch_mldata(dataname, target_name=2, data_name=0, data_home=tmpdir) for n in ["COL_NAMES", "DESCR", "target", "data", "x"]: assert_in(n, dset) assert_not_in("y", dset) assert_not_in("z", dset) assert_array_equal(dset.data, z) assert_array_equal(dset.target, y) # by name dset = fetch_mldata(dataname, target_name='y', data_name='z', data_home=tmpdir) for n in ["COL_NAMES", "DESCR", "target", "data", "x"]: assert_in(n, dset) assert_not_in("y", dset) assert_not_in("z", dset) finally: datasets.mldata.urlopen = _urlopen_ref