Python: Mocking a context manager

I don’t understand why I can’t mock NamedTemporaryFile.name in this example:

from unittest.mock import patch
import unittest
import tempfile

def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name

class TestMock(unittest.TestCase):
    @patch('tempfile.NamedTemporaryFile')
    def test_cm(self, mock_tmp):
        mytmpname = 'abcde'
        mock_tmp.__enter__.return_value.name = mytmpname
        self.assertEqual(myfunc(), mytmpname)

Running the test fails with:

AssertionError: <MagicMock name='NamedTemporaryFile().__enter__().name' id='140275675011280'> != 'abcde'

Answers:


Method 1

You are setting the wrong mock: mock_tmp is not the context manager; calling it returns the context manager. Replace your setup line with:

mock_tmp.return_value.__enter__.return_value.name = mytmpname

and your test will work.
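For reference, here is the question’s test with only that setup line changed, as a self-contained sketch. It assumes Python 3’s built-in unittest.mock rather than the older mock package, and the extra assert_called_once_with() check is an addition, not part of the original answer.

```python
import tempfile
import unittest
from unittest.mock import patch


def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name


class TestMock(unittest.TestCase):
    @patch("tempfile.NamedTemporaryFile")
    def test_cm(self, mock_tmp):
        mytmpname = "abcde"
        # mock_tmp() returns the context manager; calling __enter__() on
        # that returns the object bound by "as mytmp".
        mock_tmp.return_value.__enter__.return_value.name = mytmpname
        self.assertEqual(myfunc(), mytmpname)
        # The patched factory was called exactly once, with no arguments.
        mock_tmp.assert_called_once_with()
```

Run it with python -m unittest followed by the file name.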

Method 2

To expand on Nathaniel’s answer, this code block

with tempfile.NamedTemporaryFile() as mytmp:
    return mytmp.name

effectively does three things (and calls __exit__ on the context manager when the block is left):

# Firstly, it calls NamedTemporaryFile, which creates and returns the context manager.
context_manager = tempfile.NamedTemporaryFile()  

# Secondly, it calls __enter__ on the context manager instance.
mytmp = context_manager.__enter__()  

# Thirdly, we are now "inside" the context and can do some work. 
return mytmp.name
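The same steps can be observed on a real object. The sketch below uses a hypothetical Recorder class (not part of any library) that logs each step of the protocol. It also shows that __exit__ runs after the return expression inside the with block has been evaluated, which is why myfunc can return mytmp.name from inside the block.

```python
class Recorder:
    """Hypothetical context manager that logs each protocol step."""

    def __init__(self, log):
        self.log = log
        self.log.append("call")   # step 1: the factory/constructor is called

    def __enter__(self):
        self.log.append("enter")  # step 2: __enter__ runs; its result is bound by "as"
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.log.append("exit")   # extra step: __exit__ runs when the block is left
        return False              # do not suppress exceptions


def use_recorder(log):
    with Recorder(log) as rec:
        log.append("body")        # step 3: work done inside the block
        return rec                # evaluated before __exit__ runs


log = []
result = use_recorder(log)
print(log)  # ['call', 'enter', 'body', 'exit']
```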

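When you replace tempfile.NamedTemporaryFile with a MagicMock (which is what patch creates by default; a plain Mock does not support the with statement unless you configure __enter__ and __exit__ yourself), the same steps become: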

context_manager = mock_tmp()
# This first line, above, will call mock_tmp().
# Therefore we need to set the return_value with
# mock_tmp.return_value

mytmp = context_manager.__enter__()
# This will call mock_tmp.return_value.__enter__(), so we need to set
# mock_tmp.return_value.__enter__.return_value

return mytmp.name
# This will access mock_tmp.return_value.__enter__.return_value.name
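A MagicMock records each of these calls in mock_calls, so the chain can be checked directly. This is a minimal sketch using only unittest.mock; the names mock_tmp and wrong are illustrative. The last part shows why the question’s setup line had no effect.

```python
from unittest.mock import MagicMock, call

# Stand-in for the patched tempfile.NamedTemporaryFile.
mock_tmp = MagicMock(name="NamedTemporaryFile")

# Configure the object that "as mytmp" receives: the result of calling
# __enter__ on the result of calling mock_tmp.
mock_tmp.return_value.__enter__.return_value.name = "abcde"

with mock_tmp() as mytmp:
    name = mytmp.name

assert name == "abcde"

# The recorded calls follow the steps above, with __exit__ added at the end.
assert mock_tmp.mock_calls == [
    call(),
    call().__enter__(),
    call().__exit__(None, None, None),
]

# The question's setup configures mock_tmp.__enter__ directly. That return
# value is only used by "with mock_tmp:", never by "with mock_tmp():".
wrong = MagicMock()
wrong.__enter__.return_value.name = "abcde"
with wrong as tmp:
    assert tmp.name == "abcde"
with wrong() as tmp:
    assert tmp.name != "abcde"  # an auto-created child MagicMock instead
```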

Method 3

Extending Peter K’s answer using pytest and the mocker fixture (provided by the pytest-mock plugin):

import tempfile


def myfunc():
    with tempfile.NamedTemporaryFile(prefix='fileprefix') as fh:
        return fh.name


def test_myfunc(mocker):
    mocker.patch('tempfile.NamedTemporaryFile').return_value.__enter__.return_value.name = 'tempfilename'
    assert myfunc() == 'tempfilename'
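If the pytest-mock plugin is not available, roughly the same test can be written with the standard library’s patch used as a context manager. This sketch also checks the prefix argument with assert_called_once_with; that extra assertion is an addition to the answer above.

```python
import tempfile
from unittest.mock import patch


def myfunc():
    with tempfile.NamedTemporaryFile(prefix="fileprefix") as fh:
        return fh.name


def test_myfunc():
    # patch() as a context manager plays the role of mocker.patch(); the
    # patch is undone when the with block ends.
    with patch("tempfile.NamedTemporaryFile") as mock_tmp:
        mock_tmp.return_value.__enter__.return_value.name = "tempfilename"
        assert myfunc() == "tempfilename"
    # The mock also recorded the arguments myfunc passed to the factory.
    mock_tmp.assert_called_once_with(prefix="fileprefix")
```

pytest collects test_myfunc as is, and it can also be called directly as a plain function.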

Method 4

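Here is an alternative with pytest and the mocker fixture, which is also common practice. Instead of patching tempfile.NamedTemporaryFile, it replaces the tempfile module as seen from the module that defines myfunc (here, the test module itself, hence __name__ + '.tempfile'):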

def test_myfunc(mocker):
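    mock_tempfile = mocker.MagicMock(name='tempfile')
    # Replace the module-level name "tempfile" in this module, so myfunc
    # (defined here) sees the mock instead of the real tempfile module.
    mocker.patch(__name__ + '.tempfile', new=mock_tempfile)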
    mytmpname = 'abcde'
    mock_tempfile.NamedTemporaryFile.return_value.__enter__.return_value.name = mytmpname
    assert myfunc() == mytmpname
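Patching the name where the code under test looks it up matters even more when that code imports the function directly (from tempfile import NamedTemporaryFile). The sketch below assumes such a module, called mymodule here as a hypothetical name; it is built with types.ModuleType and exec only so the example fits in one file. In a real project you would write patch("mymodule.NamedTemporaryFile") instead of patch.object(mymodule, ...).

```python
import types
from unittest.mock import patch

# Source of a hypothetical module under test ("mymodule") that imports the
# name directly instead of importing the tempfile module.
MODULE_SOURCE = """
from tempfile import NamedTemporaryFile


def myfunc():
    with NamedTemporaryFile() as fh:
        return fh.name
"""

# Build the module in memory only so this sketch fits in one file.
mymodule = types.ModuleType("mymodule")
exec(MODULE_SOURCE, mymodule.__dict__)


def name_with_source_patched():
    # Patches the attribute on the tempfile module. mymodule already holds
    # its own reference to the real function, so the patch is not seen.
    with patch("tempfile.NamedTemporaryFile") as mock_tmp:
        mock_tmp.return_value.__enter__.return_value.name = "abcde"
        return mymodule.myfunc()


def name_with_lookup_patched():
    # Patches the name where myfunc looks it up. For an importable module
    # this would be patch("mymodule.NamedTemporaryFile").
    with patch.object(mymodule, "NamedTemporaryFile") as mock_tmp:
        mock_tmp.return_value.__enter__.return_value.name = "abcde"
        return mymodule.myfunc()


print(name_with_source_patched())  # a real temporary file path
print(name_with_lookup_patched())  # abcde
```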

Method 5

I extended hmobrienv’s answer into a small, self-contained program (it needs pytest and the pytest-mock plugin):

import tempfile
import pytest


def myfunc():
    with tempfile.NamedTemporaryFile(prefix="fileprefix") as fh:
        return fh.name


def test_myfunc(mocker):
    mocker.patch("tempfile.NamedTemporaryFile").return_value.__enter__.return_value.name = "tempfilename"
    assert myfunc() == "tempfilename"


if __name__ == "__main__":
    pytest.main(args=[__file__])

Method 6

Another possibility is to use a factory to create an object that implements the context manager interface:

import unittest
import unittest.mock
import tempfile


def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name


def mock_named_temporary_file(tmpname):
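    class MockNamedTemporaryFile:
        def __init__(self, *args, **kwargs):
            self.name = tmpname

        def __enter__(self):
            # The object bound by "as" is the fake file itself.
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # A falsy return value lets exceptions from the block propagate.
            return False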

    return MockNamedTemporaryFile()


class TestMock(unittest.TestCase):
    @unittest.mock.patch("tempfile.NamedTemporaryFile")
    def test_cm(self, mock_tmp):
        mytmpname = "abcde"
        mock_tmp.return_value = mock_named_temporary_file(mytmpname)
        self.assertEqual(myfunc(), mytmpname)
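On Python 3.7 and later, the hand-written class can be replaced with standard-library pieces: contextlib.nullcontext(x) is a context manager whose __enter__ returns x, and types.SimpleNamespace gives an object with a name attribute. This is a swapped-in technique rather than the answer’s factory, sketched here under those assumptions.

```python
import contextlib
import tempfile
import types
import unittest
from unittest.mock import patch


def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name


class TestMock(unittest.TestCase):
    @patch("tempfile.NamedTemporaryFile")
    def test_cm(self, mock_tmp):
        mytmpname = "abcde"
        # SimpleNamespace stands in for the file object; nullcontext hands
        # it back unchanged from __enter__.
        fake_file = types.SimpleNamespace(name=mytmpname)
        mock_tmp.return_value = contextlib.nullcontext(fake_file)
        self.assertEqual(myfunc(), mytmpname)
```

Unlike a MagicMock, the SimpleNamespace stand-in has only the attributes you give it, so if myfunc later starts calling another file method (say, write), the test fails with an AttributeError instead of silently passing.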


All methods were sourced from stackoverflow.com or stackexchange.com and are licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
