diff --git a/lib/ansible/module_utils/postgres.py b/lib/ansible/module_utils/postgres.py index 5d6232930d9..ab982135b4d 100644 --- a/lib/ansible/module_utils/postgres.py +++ b/lib/ansible/module_utils/postgres.py @@ -66,52 +66,24 @@ def ensure_required_libs(module): module.fail_json(msg='psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter') -def connect_to_db(module, autocommit=False, fail_on_conn=True, warn_db_default=True): - """Return psycopg2 connection object. - - Keyword arguments: - module -- object of ansible.module_utils.basic.AnsibleModule class - autocommit -- commit automatically (default False) - fail_on_conn -- fail if connection failed or just warn and return None (default True) - warn_db_default -- warn that the default DB is used (default True) - """ - ensure_required_libs(module) - - # To use defaults values, keyword arguments must be absent, so - # check which values are empty and don't include in the **kw - # dictionary - params_map = { - "login_host": "host", - "login_user": "user", - "login_password": "password", - "port": "port", - "ssl_mode": "sslmode", - "ca_cert": "sslrootcert" - } +def connect_to_db(module, conn_params, autocommit=False, fail_on_conn=True): + """Connect to a PostgreSQL database. - # Might be different in the modules: - if module.params.get('db'): - params_map['db'] = 'database' - elif module.params.get('database'): - params_map['database'] = 'database' - elif module.params.get('login_db'): - params_map['login_db'] = 'database' - else: - if warn_db_default: - module.warn('Database name has not been passed, ' - 'used default database to connect to.') + Return psycopg2 connection object. - kw = dict((params_map[k], v) for (k, v) in iteritems(module.params) - if k in params_map and v != '' and v is not None) + Args: + module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class + conn_params (dict) -- dictionary with connection parameters - # If a login_unix_socket is specified, incorporate it here. - is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost" - if is_localhost and module.params["login_unix_socket"] != "": - kw["host"] = module.params["login_unix_socket"] + Kwargs: + autocommit (bool) -- commit automatically (default False) + fail_on_conn (bool) -- fail if connection failed or just warn and return None (default True) + """ + ensure_required_libs(module) db_connection = None try: - db_connection = psycopg2.connect(**kw) + db_connection = psycopg2.connect(**conn_params) if autocommit: if LooseVersion(psycopg2.__version__) >= LooseVersion('2.4.2'): db_connection.set_session(autocommit=True) @@ -179,3 +151,49 @@ def exec_sql(obj, query, ddl=False, add_to_executed=True): except Exception as e: obj.module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e))) return False + + +def get_conn_params(module, params_dict, warn_db_default=True): + """Get connection parameters from the passed dictionary. + + Return a dictionary with parameters to connect to PostgreSQL server. + + Args: + module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class + params_dict (dict) -- dictionary with variables + + Kwargs: + warn_db_default (bool) -- warn that the default DB is used (default True) + """ + # To use defaults values, keyword arguments must be absent, so + # check which values are empty and don't include in the return dictionary + params_map = { + "login_host": "host", + "login_user": "user", + "login_password": "password", + "port": "port", + "ssl_mode": "sslmode", + "ca_cert": "sslrootcert" + } + + # Might be different in the modules: + if params_dict.get('db'): + params_map['db'] = 'database' + elif params_dict.get('database'): + params_map['database'] = 'database' + elif params_dict.get('login_db'): + params_map['login_db'] = 'database' + else: + if warn_db_default: + module.warn('Database name has not been passed, ' + 'used default database to connect to.') + + kw = dict((params_map[k], v) for (k, v) in iteritems(params_dict) + if k in params_map and v != '' and v is not None) + + # If a login_unix_socket is specified, incorporate it here. + is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost" + if is_localhost and params_dict["login_unix_socket"] != "": + kw["host"] = params_dict["login_unix_socket"] + + return kw diff --git a/lib/ansible/modules/database/postgresql/postgresql_copy.py b/lib/ansible/modules/database/postgresql/postgresql_copy.py index 45810b31dd6..cc369b6efff 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_copy.py +++ b/lib/ansible/modules/database/postgresql/postgresql_copy.py @@ -178,6 +178,7 @@ from ansible.module_utils.database import pg_quote_identifier from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) from ansible.module_utils.six import iteritems @@ -351,7 +352,8 @@ def main(): module.fail_json(msg='src param is necessary with copy_to') # Connect to DB and make cursor object: - db_connection = connect_to_db(module, autocommit=False) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) cursor = db_connection.cursor(cursor_factory=DictCursor) ############## diff --git a/lib/ansible/modules/database/postgresql/postgresql_ext.py b/lib/ansible/modules/database/postgresql/postgresql_ext.py index 649d894b371..f37672a2d1b 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_ext.py +++ b/lib/ansible/modules/database/postgresql/postgresql_ext.py @@ -143,7 +143,11 @@ except ImportError: pass from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) from ansible.module_utils._text import to_native from ansible.module_utils.database import pg_quote_identifier @@ -216,7 +220,8 @@ def main(): cascade = module.params["cascade"] changed = False - db_connection = connect_to_db(module, autocommit=True) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=True) cursor = db_connection.cursor(cursor_factory=DictCursor) try: diff --git a/lib/ansible/modules/database/postgresql/postgresql_idx.py b/lib/ansible/modules/database/postgresql/postgresql_idx.py index 6d711ebf9cd..b995e3a4699 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_idx.py +++ b/lib/ansible/modules/database/postgresql/postgresql_idx.py @@ -230,6 +230,7 @@ from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -474,7 +475,8 @@ def main(): if cascade and state != 'absent': module.fail_json(msg="cascade parameter used only with state=absent") - db_connection = connect_to_db(module, autocommit=True) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=True) cursor = db_connection.cursor(cursor_factory=DictCursor) # Set defaults: diff --git a/lib/ansible/modules/database/postgresql/postgresql_info.py b/lib/ansible/modules/database/postgresql/postgresql_info.py index 08a5e0daf0c..567328c64d4 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_info.py +++ b/lib/ansible/modules/database/postgresql/postgresql_info.py @@ -475,7 +475,11 @@ except ImportError: pass from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) from ansible.module_utils._text import to_native @@ -502,7 +506,8 @@ class PgDbConn(object): Note: connection parameters are passed by self.module object. """ - self.db_conn = connect_to_db(self.module, warn_db_default=False) + conn_params = get_conn_params(self.module, self.module.params, warn_db_default=False) + self.db_conn = connect_to_db(self.module, conn_params) return self.db_conn.cursor(cursor_factory=DictCursor) def reconnect(self, dbname): diff --git a/lib/ansible/modules/database/postgresql/postgresql_lang.py b/lib/ansible/modules/database/postgresql/postgresql_lang.py index 17ef639d844..c601d094aa6 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_lang.py +++ b/lib/ansible/modules/database/postgresql/postgresql_lang.py @@ -170,7 +170,11 @@ queries: ''' from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) from ansible.module_utils._text import to_native from ansible.module_utils.database import pg_quote_identifier @@ -254,7 +258,8 @@ def main(): cascade = module.params["cascade"] fail_on_drop = module.params["fail_on_drop"] - db_connection = connect_to_db(module, autocommit=False) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) cursor = db_connection.cursor() changed = False diff --git a/lib/ansible/modules/database/postgresql/postgresql_membership.py b/lib/ansible/modules/database/postgresql/postgresql_membership.py index a121bccc76d..0dd5257fb58 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_membership.py +++ b/lib/ansible/modules/database/postgresql/postgresql_membership.py @@ -147,6 +147,7 @@ from ansible.module_utils.database import pg_quote_identifier from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -284,7 +285,8 @@ def main(): fail_on_role = module.params['fail_on_role'] state = module.params['state'] - db_connection = connect_to_db(module, autocommit=False) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) cursor = db_connection.cursor(cursor_factory=DictCursor) ############## diff --git a/lib/ansible/modules/database/postgresql/postgresql_owner.py b/lib/ansible/modules/database/postgresql/postgresql_owner.py index 37ec8256c5a..d0e0cc92645 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_owner.py +++ b/lib/ansible/modules/database/postgresql/postgresql_owner.py @@ -161,6 +161,7 @@ from ansible.module_utils.database import pg_quote_identifier from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -415,7 +416,8 @@ def main(): reassign_owned_by = module.params['reassign_owned_by'] fail_on_role = module.params['fail_on_role'] - db_connection = connect_to_db(module, autocommit=False) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) cursor = db_connection.cursor(cursor_factory=DictCursor) ############## diff --git a/lib/ansible/modules/database/postgresql/postgresql_ping.py b/lib/ansible/modules/database/postgresql/postgresql_ping.py index 12bb3792bf4..f72174cad8c 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_ping.py +++ b/lib/ansible/modules/database/postgresql/postgresql_ping.py @@ -82,6 +82,7 @@ from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -138,7 +139,8 @@ def main(): server_version=dict(), ) - db_connection = connect_to_db(module, fail_on_conn=False) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, fail_on_conn=False) if db_connection is not None: cursor = db_connection.cursor(cursor_factory=DictCursor) diff --git a/lib/ansible/modules/database/postgresql/postgresql_query.py b/lib/ansible/modules/database/postgresql/postgresql_query.py index 2823afefa1c..21f80543710 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_query.py +++ b/lib/ansible/modules/database/postgresql/postgresql_query.py @@ -146,7 +146,11 @@ except ImportError: pass from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) from ansible.module_utils._text import to_native @@ -189,7 +193,8 @@ def main(): except Exception as e: module.fail_json(msg="Cannot read file '%s' : %s" % (path_to_script, to_native(e))) - db_connection = connect_to_db(module, autocommit=False) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) cursor = db_connection.cursor(cursor_factory=DictCursor) # Prepare args: diff --git a/lib/ansible/modules/database/postgresql/postgresql_schema.py b/lib/ansible/modules/database/postgresql/postgresql_schema.py index 18b071c47ee..44ae53a7079 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_schema.py +++ b/lib/ansible/modules/database/postgresql/postgresql_schema.py @@ -129,7 +129,11 @@ except ImportError: pass from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) from ansible.module_utils.database import SQLParseError, pg_quote_identifier from ansible.module_utils._text import to_native @@ -234,7 +238,8 @@ def main(): cascade_drop = module.params["cascade_drop"] changed = False - db_connection = connect_to_db(module, autocommit=True) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=True) cursor = db_connection.cursor(cursor_factory=DictCursor) try: diff --git a/lib/ansible/modules/database/postgresql/postgresql_sequence.py b/lib/ansible/modules/database/postgresql/postgresql_sequence.py index 0890ecc7460..ec8806dc3d8 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_sequence.py +++ b/lib/ansible/modules/database/postgresql/postgresql_sequence.py @@ -287,6 +287,7 @@ from ansible.module_utils.database import pg_quote_identifier from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -498,7 +499,8 @@ def main(): # Change autocommit to False if check_mode: autocommit = not module.check_mode # Connect to DB and make cursor object: - db_connection = connect_to_db(module, autocommit=autocommit) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=autocommit) cursor = db_connection.cursor(cursor_factory=DictCursor) ############## diff --git a/lib/ansible/modules/database/postgresql/postgresql_set.py b/lib/ansible/modules/database/postgresql/postgresql_set.py index 57b8b31cd6c..0b7259a0f34 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_set.py +++ b/lib/ansible/modules/database/postgresql/postgresql_set.py @@ -165,7 +165,11 @@ except Exception: from copy import deepcopy from ansible.module_utils.basic import AnsibleModule -from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) from ansible.module_utils._text import to_native PG_REQ_VER = 90400 @@ -304,7 +308,8 @@ def main(): if not value and not reset: module.fail_json(msg="%s: at least one of value or reset param must be specified" % name) - db_connection = connect_to_db(module, autocommit=True, warn_db_default=False) + conn_params = get_conn_params(module, module.params, warn_db_default=False) + db_connection = connect_to_db(module, conn_params, autocommit=True) cursor = db_connection.cursor(cursor_factory=DictCursor) kw = {} @@ -397,7 +402,7 @@ def main(): # Reconnect and recheck current value: if context in ('sighup', 'superuser-backend', 'backend', 'superuser', 'user'): - db_connection = connect_to_db(module, autocommit=True) + db_connection = connect_to_db(module, conn_params, autocommit=True) cursor = db_connection.cursor(cursor_factory=DictCursor) res = param_get(cursor, module, name) diff --git a/lib/ansible/modules/database/postgresql/postgresql_slot.py b/lib/ansible/modules/database/postgresql/postgresql_slot.py index 3427764bd66..61ed8bfc84a 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_slot.py +++ b/lib/ansible/modules/database/postgresql/postgresql_slot.py @@ -152,6 +152,7 @@ from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -242,7 +243,8 @@ def main(): if immediately_reserve and slot_type == 'logical': module.fail_json(msg="Module parameters immediately_reserve and slot_type=logical are mutually exclusive") - db_connection = connect_to_db(module, autocommit=True) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=True) cursor = db_connection.cursor(cursor_factory=DictCursor) ################################## diff --git a/lib/ansible/modules/database/postgresql/postgresql_table.py b/lib/ansible/modules/database/postgresql/postgresql_table.py index 20b99d77abd..94c9f85360d 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_table.py +++ b/lib/ansible/modules/database/postgresql/postgresql_table.py @@ -240,6 +240,7 @@ from ansible.module_utils.database import pg_quote_identifier from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -514,7 +515,8 @@ def main(): if including and not like: module.fail_json(msg="%s: including param needs like param specified" % table) - db_connection = connect_to_db(module, autocommit=False) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) cursor = db_connection.cursor(cursor_factory=DictCursor) if storage_params: diff --git a/lib/ansible/modules/database/postgresql/postgresql_tablespace.py b/lib/ansible/modules/database/postgresql/postgresql_tablespace.py index e34136f29b5..fc1c6ecf651 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_tablespace.py +++ b/lib/ansible/modules/database/postgresql/postgresql_tablespace.py @@ -176,6 +176,7 @@ from ansible.module_utils.database import pg_quote_identifier from ansible.module_utils.postgres import ( connect_to_db, exec_sql, + get_conn_params, postgres_common_argument_spec, ) @@ -394,7 +395,8 @@ def main(): module.fail_json(msg="state=absent is mutually exclusive location, " "owner, rename_to, and set") - db_connection = connect_to_db(module, autocommit=True) + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=True) cursor = db_connection.cursor(cursor_factory=DictCursor) # Change autocommit to False if check_mode: diff --git a/lib/ansible/modules/database/postgresql/postgresql_user.py b/lib/ansible/modules/database/postgresql/postgresql_user.py index 4242aac82b0..b328d072ff2 100644 --- a/lib/ansible/modules/database/postgresql/postgresql_user.py +++ b/lib/ansible/modules/database/postgresql/postgresql_user.py @@ -244,7 +244,11 @@ except ImportError: from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.database import pg_quote_identifier, SQLParseError -from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) from ansible.module_utils._text import to_bytes, to_native from ansible.module_utils.six import iteritems @@ -801,7 +805,8 @@ def main(): conn_limit = module.params["conn_limit"] role_attr_flags = module.params["role_attr_flags"] - db_connection = connect_to_db(module, warn_db_default=False) + conn_params = get_conn_params(module, module.params, warn_db_default=False) + db_connection = connect_to_db(module, conn_params) cursor = db_connection.cursor(cursor_factory=DictCursor) try: diff --git a/test/units/module_utils/postgresql/test_postgres.py b/test/units/module_utils/postgresql/test_postgres.py index a895106aa8e..d0ca60c1d25 100644 --- a/test/units/module_utils/postgresql/test_postgres.py +++ b/test/units/module_utils/postgresql/test_postgres.py @@ -6,6 +6,33 @@ import pytest import ansible.module_utils.postgres as pg +INPUT_DICT = dict( + session_role=dict(default=''), + login_user=dict(default='postgres'), + login_password=dict(default='test', no_log=True), + login_host=dict(default='test'), + login_unix_socket=dict(default=''), + port=dict(type='int', default=5432, aliases=['login_port']), + ssl_mode=dict( + default='prefer', + choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full'] + ), + ca_cert=dict(aliases=['ssl_rootcert']), +) + +EXPECTED_DICT = dict( + user=dict(default='postgres'), + password=dict(default='test', no_log=True), + host=dict(default='test'), + port=dict(type='int', default=5432, aliases=['login_port']), + sslmode=dict( + default='prefer', + choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full'] + ), + sslrootcert=dict(aliases=['ssl_rootcert']), +) + + class TestPostgresCommonArgSpec(): """ @@ -154,6 +181,24 @@ class TestEnsureReqLibs(): assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg +@pytest.fixture(scope='class') +def m_ansible_module(): + """Return an object of dummy AnsibleModule class.""" + class DummyAnsibleModule(): + def __init__(self): + self.params = pg.postgres_common_argument_spec() + self.err_msg = '' + self.warn_msg = '' + + def fail_json(self, msg): + self.err_msg = msg + + def warn(self, msg): + self.warn_msg = msg + + return DummyAnsibleModule() + + class TestConnectToDb(): """ @@ -168,29 +213,13 @@ class TestConnectToDb(): 2. Types of return objects (db_connection and cursor). """ - @pytest.fixture(scope='class') - def m_ansible_module(self): - """Return an object of dummy AnsibleModule class.""" - class DummyAnsibleModule(): - def __init__(self): - self.params = pg.postgres_common_argument_spec() - self.err_msg = '' - self.warn_msg = '' - - def fail_json(self, msg): - self.err_msg = msg - - def warn(self, msg): - self.warn_msg = msg - - return DummyAnsibleModule() - def test_connect_to_db(self, m_ansible_module, monkeypatch, m_psycopg2): """Test connect_to_db(), common test.""" monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True) monkeypatch.setattr(pg, 'psycopg2', m_psycopg2) - db_connection = pg.connect_to_db(m_ansible_module) + conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params) + db_connection = pg.connect_to_db(m_ansible_module, conn_params) cursor = db_connection.cursor() # if errors, db_connection returned as None: assert isinstance(db_connection, DbConnection) @@ -205,7 +234,8 @@ class TestConnectToDb(): monkeypatch.setattr(pg, 'psycopg2', m_psycopg2) m_ansible_module.params['session_role'] = 'test_role' - db_connection = pg.connect_to_db(m_ansible_module) + conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params) + db_connection = pg.connect_to_db(m_ansible_module, conn_params) cursor = db_connection.cursor() # if errors, db_connection returned as None: assert isinstance(db_connection, DbConnection) @@ -214,25 +244,6 @@ class TestConnectToDb(): # The default behaviour, normal in this case: assert 'Database name has not been passed' in m_ansible_module.warn_msg - def test_warn_db_default_non_default(self, m_ansible_module, monkeypatch, m_psycopg2): - """ - Test connect_to_db(), warn_db_default arg passed as False (by default is True). - """ - monkeypatch.setattr(pg, 'HAS_PSYCOPG2', True) - monkeypatch.setattr(pg, 'psycopg2', m_psycopg2) - - db_connection = pg.connect_to_db(m_ansible_module, warn_db_default=False) - cursor = db_connection.cursor() - # if errors, db_connection returned as None: - assert isinstance(db_connection, DbConnection) - assert isinstance(cursor, Cursor) - assert m_ansible_module.err_msg == '' - assert m_ansible_module.warn_msg == '' - # pay attention that warn_db_defaul=True has been checked - # in the previous tests by - # assert('Database name has not been passed' in m_ansible_module.warn_msg) - # because of this is the default behavior - def test_fail_on_conn_true(self, m_ansible_module, monkeypatch, m_psycopg2): """ Test connect_to_db(), fail_on_conn arg passed as True (the default behavior). @@ -242,7 +253,8 @@ class TestConnectToDb(): m_ansible_module.params['login_user'] = 'Exception' # causes Exception - db_connection = pg.connect_to_db(m_ansible_module, fail_on_conn=True) + conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params) + db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=True) assert 'unable to connect to database' in m_ansible_module.err_msg assert db_connection is None @@ -256,7 +268,8 @@ class TestConnectToDb(): m_ansible_module.params['login_user'] = 'Exception' # causes Exception - db_connection = pg.connect_to_db(m_ansible_module, fail_on_conn=False) + conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params) + db_connection = pg.connect_to_db(m_ansible_module, conn_params, fail_on_conn=False) assert m_ansible_module.err_msg == '' assert 'PostgreSQL server is unavailable' in m_ansible_module.warn_msg @@ -271,7 +284,8 @@ class TestConnectToDb(): # case 1: psycopg2.__version >= 2.4.2 (the default in m_psycopg2) monkeypatch.setattr(pg, 'psycopg2', m_psycopg2) - db_connection = pg.connect_to_db(m_ansible_module, autocommit=True) + conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params) + db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True) cursor = db_connection.cursor() # if errors, db_connection returned as None: @@ -283,10 +297,26 @@ class TestConnectToDb(): m_psycopg2.__version__ = '2.4.1' monkeypatch.setattr(pg, 'psycopg2', m_psycopg2) - db_connection = pg.connect_to_db(m_ansible_module, autocommit=True) + conn_params = pg.get_conn_params(m_ansible_module, m_ansible_module.params) + db_connection = pg.connect_to_db(m_ansible_module, conn_params, autocommit=True) cursor = db_connection.cursor() # if errors, db_connection returned as None: assert isinstance(db_connection, DbConnection) assert isinstance(cursor, Cursor) assert 'psycopg2 must be at least 2.4.3' in m_ansible_module.err_msg + + +class TestGetConnParams(): + + """Namespace for testing get_conn_params() function.""" + + def test_get_conn_params_def(self, m_ansible_module): + """Test get_conn_params(), warn_db_default kwarg is default.""" + assert pg.get_conn_params(m_ansible_module, INPUT_DICT) == EXPECTED_DICT + assert m_ansible_module.warn_msg == 'Database name has not been passed, used default database to connect to.' + + def test_get_conn_params_warn_db_def_false(self, m_ansible_module): + """Test get_conn_params(), warn_db_default kwarg is False.""" + assert pg.get_conn_params(m_ansible_module, INPUT_DICT, warn_db_default=False) == EXPECTED_DICT + assert m_ansible_module.warn_msg == ''