from __future__ import absolute_import, print_function import os import sys import tempfile import numpy from numpy.testing import TestCase, assert_, run_module_suite from scipy.weave import inline_tools, ext_tools from scipy.weave.build_tools import msvc_exists, gcc_exists from scipy.weave.catalog import unique_file from scipy.weave.numpy_scalar_spec import numpy_complex_scalar_converter from weave_test_utils import dec def unique_mod(d,file_name): f = os.path.basename(unique_file(d,file_name)) m = os.path.splitext(f)[0] return m #---------------------------------------------------------------------------- # Scalar conversion test classes # int, float, complex #---------------------------------------------------------------------------- class NumpyComplexScalarConverter(TestCase): compiler = '' def setUp(self): self.converter = numpy_complex_scalar_converter() @dec.slow def test_type_match_string(self): assert_(not self.converter.type_match('string')) @dec.slow def test_type_match_int(self): assert_(not self.converter.type_match(5)) @dec.slow def test_type_match_float(self): assert_(not self.converter.type_match(5.)) @dec.slow def test_type_match_complex128(self): assert_(self.converter.type_match(numpy.complex128(5.+1j))) @dec.slow def test_complex_var_in(self): mod_name = sys._getframe().f_code.co_name + self.compiler mod_name = unique_mod(test_dir,mod_name) mod = ext_tools.ext_module(mod_name) a = numpy.complex(1.+1j) code = "a=std::complex(2.,2.);" test = ext_tools.ext_function('test',code,['a']) mod.add_function(test) mod.compile(location=test_dir, compiler=self.compiler) exec('from ' + mod_name + ' import test') b = numpy.complex128(1.+1j) test(b) try: b = 1. test(b) except TypeError: pass try: b = 'abc' test(b) except TypeError: pass @dec.slow def test_complex_return(self): mod_name = sys._getframe().f_code.co_name + self.compiler mod_name = unique_mod(test_dir,mod_name) mod = ext_tools.ext_module(mod_name) a = 1.+1j code = """ a= a + std::complex(2.,2.); return_val = PyComplex_FromDoubles(a.real(),a.imag()); """ test = ext_tools.ext_function('test',code,['a']) mod.add_function(test) mod.compile(location=test_dir, compiler=self.compiler) exec('from ' + mod_name + ' import test') b = 1.+1j c = test(b) assert_(c == 3.+3j) @dec.slow def test_inline(self): a = numpy.complex128(1+1j) result = inline_tools.inline("return_val=1.0/a;",['a']) assert_(result == .5-.5j) for _n in dir(): if _n[-9:] == 'Converter': if msvc_exists(): exec("class Test%sMsvc(%s):\n compiler = 'msvc'" % (_n,_n)) else: exec("class Test%sUnix(%s):\n compiler = ''" % (_n,_n)) if gcc_exists(): exec("class Test%sGcc(%s):\n compiler = 'gcc'" % (_n,_n)) def setup_test_location(): test_dir = tempfile.mkdtemp() sys.path.insert(0,test_dir) return test_dir test_dir = setup_test_location() def teardown_test_location(): import tempfile test_dir = os.path.join(tempfile.gettempdir(),'test_files') if sys.path[0] == test_dir: sys.path = sys.path[1:] return test_dir if not msvc_exists(): for _n in dir(): if _n[:8] == 'TestMsvc': exec('del '+_n) else: for _n in dir(): if _n[:8] == 'TestUnix': exec('del '+_n) if not (gcc_exists() and msvc_exists() and sys.platform == 'win32'): for _n in dir(): if _n[:7] == 'TestGcc': exec('del '+_n) if __name__ == "__main__": run_module_suite()