aboutsummaryrefslogtreecommitdiff
path: root/test/lint/lint-include-guards.py
blob: 48b918e9dab6f9fa5f84bb767a81d3fd6ad3abb7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3
#
# Copyright (c) 2018-2022 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.

"""
Check include guards.
"""

import re
import sys
from subprocess import check_output
from typing import List


# An expected include guard id is built as
# HEADER_ID_PREFIX + <id derived from the header path> + HEADER_ID_SUFFIX.
HEADER_ID_PREFIX = 'BITCOIN_'
HEADER_ID_SUFFIX = '_H'

# Paths (files or directories) whose headers are skipped by the
# include guard check.
EXCLUDE_FILES_WITH_PREFIX = [
    'contrib/devtools/bitcoin-tidy',
    'src/crypto/ctaes',
    'src/leveldb',
    'src/crc32c',
    'src/secp256k1',
    'src/minisketch',
    'src/tinyformat.h',
    'src/bench/nanobench.h',
    'src/test/fuzz/FuzzedDataProvider.h',
]


def _get_header_file_lst() -> List[str]:
    """ Helper function to get a list of header filepaths to be
        checked for include guards.

    Lists every git-tracked '*.h' file and drops those whose path
    starts with one of the entries in EXCLUDE_FILES_WITH_PREFIX.

    Returns:
        The header filepaths, as printed by 'git ls-files', that are
        not excluded.

    Raises:
        subprocess.CalledProcessError: If the git command exits with a
            non-zero status.
    """
    git_cmd_lst = ['git', 'ls-files', '--', '*.h']
    header_file_lst = check_output(
        git_cmd_lst).decode('utf-8').splitlines()

    # str.startswith accepts a tuple of candidates. Matching on the
    # prefix, rather than anywhere in the path, keeps an unrelated header
    # whose path merely contains an excluded fragment from being skipped.
    # NOTE(review): the exclusion entries look repo-root-relative, so
    # this presumably expects to be run from the repository root.
    excluded_prefixes = tuple(EXCLUDE_FILES_WITH_PREFIX)

    return [hf for hf in header_file_lst
            if not hf.startswith(excluded_prefixes)]


def _get_header_id(header_file: str) -> str:
    """ Helper function to get the header id from a header file
        string.

        eg: 'src/wallet/walletdb.h' -> 'BITCOIN_WALLET_WALLETDB_H'

    The first path component is dropped, the remaining components are
    joined with '_', the trailing '.h' extension is removed, '-' is
    replaced with '_' and the result is upper-cased before being wrapped
    in HEADER_ID_PREFIX and HEADER_ID_SUFFIX.

    Args:
        header_file: Filepath to header file.

    Returns:
        The header id.
    """
    # Everything after the first path component (e.g. 'src') contributes
    # to the id.
    header_id_base = '_'.join(header_file.split('/')[1:])

    # Strip only the file extension. A blanket replace('.h', '') would
    # also delete '.h' occurring in the middle of the path
    # (e.g. 'foo.hash.h' would become 'fooash').
    extension = '.h'
    if header_id_base.endswith(extension):
        header_id_base = header_id_base[:-len(extension)]

    header_id_base = header_id_base.replace('-', '_').upper()

    return f'{HEADER_ID_PREFIX}{header_id_base}{HEADER_ID_SUFFIX}'


def main():
    """ Check every listed header file for its expected include guard.

    For each filepath returned by _get_header_file_lst(), counts the
    lines that start with '#ifndef', '#define' or '#endif //' followed
    by the id returned by _get_header_id(). A header is reported on
    stdout unless exactly three such lines are found.

    Exits the process with status 1 if any header was reported and with
    status 0 otherwise.
    """
    exit_code = 0

    for header_file in _get_header_file_lst():
        header_id = _get_header_id(header_file)

        # The id is derived from a file path, so escape it to have any
        # regex metacharacter matched literally. The negative lookahead
        # keeps a longer macro that merely starts with the id
        # (e.g. '<id>_FOO') from being counted as a guard line.
        guard_line_re = re.compile(
            rf'^#(ifndef|define|endif //) {re.escape(header_id)}(?!\w)')

        # Stream the file line by line; only the number of matching
        # lines is needed.
        with open(header_file, 'r', encoding='utf-8') as f:
            count = sum(1 for line in f if guard_line_re.match(line))

        if count != 3:
            print(f'{header_file} seems to be missing the expected '
                  'include guard:')
            print(f'  #ifndef {header_id}')
            print(f'  #define {header_id}')
            print('  ...')
            print(f'  #endif // {header_id}\n')
            exit_code = 1

    sys.exit(exit_code)


# Run the check only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()