sg-tightener: Replace Permissive Security Group Rules with Evidence-Based CIDR Blocks

AWS Security Group Hardening Using VPC Flow Log Analysis: Introducing sg-tightener


Replacing overly permissive security group rules becomes straightforward when actual traffic data drives the decisions. sg-tightener analyzes real network flow logs to identify which specific IP addresses genuinely communicate with your resources, then replaces broad CIDR blocks with precise, evidence-based rules that reflect observed behavior rather than assumptions, meaningfully reducing attack surface without disrupting legitimate traffic.

Article Summary
  1. What it is
     sg-tightener is an open source AWS tool that reads VPC flow logs (90 days by default) and replaces broad, permissive security group CIDR rules with the tightest evidence-based CIDR blocks that actual traffic patterns support.
  2. Why it matters
     Manually auditing every security group to find legitimate source IPs is a forensic exercise that takes weeks, so sg-tightener automates that process with a four-mode workflow (analyse, plan, apply, revert) whose plan, apply, and revert steps mirror Terraform.
  3. Key takeaway
     sg-tightener uses a gap-tolerance algorithm that, by default, accepts a CIDR block only if at most 30% of its address space was never observed in flow logs, limiting both over-permissive rules and accidental lockouts.

Andrew Baker, Group CIO, Capitec Bank

Most enterprises did not move to AWS. They extended into it. The datacenter did not go away. The VPN did not go away. The network team provisioned the Direct Connect, someone wrote a security group rule permitting the entire datacenter subnet, and that rule has been sitting there ever since, through every re-architecture, every team change, every compliance audit, silently granting 65,536 addresses the right to attempt a connection to your cloud workloads.

That is not a cloud security posture. That is a datacenter security posture with a better API.

The consequences are not theoretical. When a ransomware operator compromises a build agent on your corporate network, that /16 rule is the invitation into your AWS environment. When a vendor jump host gets taken over, your security groups already trust it. When an engineer provisioned a “temporary” VPN path for a proof of concept three years ago and nobody decommissioned it, that path exists in your security groups today as a permanent trusted network. The cloud boundary you believe you have is largely fictional because the rules that enforce it were written based on assumption rather than observation.

This post introduces sg-tightener, an open source tool that replaces assumption-based trust with evidence-based trust. Rather than asking engineers to guess at what CIDR ranges are legitimate, it reads your VPC flow logs, observes what IP addresses have actually connected to your resources, calculates the tightest CIDR blocks that cover those addresses without breaching AWS rule limits, and replaces your broad permissive rules with those empirically derived blocks. It is an extension of CloudToRepo, the open source tool for reverse-engineering AWS infrastructure into Terraform.

1. The Assumption Problem in Hybrid Cloud Networking

Enterprise security rules are almost universally assumption-based. A /16 rule exists because someone assumed the datacenter needed access. A /24 rule exists because someone assumed the application server subnet needed reach. Nobody went back to verify which addresses in those ranges ever actually connected, and nobody will, because doing so manually across hundreds of security groups is not a realistic proposition.

The result is that most enterprise AWS estates have a trust boundary that looks precise on a diagram and is almost meaningless in practice. Consider what lives inside a typical corporate /16: application servers that should reach your APIs, development workstations that absolutely should not, decommissioned servers with stale DNS records, build agents that may be externally facing, vendor jump hosts with their own security posture, monitoring infrastructure, test environments, and a long tail of machines nobody can fully account for. Your security group rule trusts all of them equally because they share a subnet.

Evidence-based security starts from the opposite position. Rather than asking “what should be trusted?”, it asks “what has been observed?” Flow logs tell you, with precision, which source IP addresses made accepted connections to which resources over any given period. That forensic record is the foundation sg-tightener builds on. The trust boundary it produces is not a guess. It is a compressed representation of your actual network behaviour.

2. Why Security Group References Alone Do Not Solve This

AWS frequently recommends using security group references instead of CIDR blocks: rather than permitting 10.0.0.0/16, you reference the security group attached to the resources that need access. That is good advice for traffic that originates entirely within AWS. It does not help with the problem this tool addresses.

Direct Connect and VPN connections terminate at a virtual private gateway or transit gateway and arrive in your VPC as source IP addresses from your corporate address space. There is no AWS security group on the other end of that connection. The traffic arrives as a CIDR and must be trusted as a CIDR. Legacy enterprise networks are built on IP-based trust models that predate the concept of security group references by decades. Third-party integrations frequently terminate into fixed IP ranges that your vendors provide. Hybrid estates cannot fully eliminate CIDR-based trust because the other end of the connection is not in AWS.

This means the class of problem sg-tightener addresses is not going away as estates mature. As long as your AWS environment connects to anything outside AWS over a network path, you have CIDR-based trust rules that need to be as narrow as possible. The tool exists specifically for that boundary.

3. What Evidence-Based CIDR Reduction Looks Like in Practice

To make this concrete, consider a typical enterprise account that has been running for three years with a Direct Connect path from a corporate datacenter. The security group inventory shows a handful of rules like 10.0.0.0/16 and 10.4.0.0/16 granting broad inbound access to application tiers, databases, and internal APIs.

Running sg-tightener analyse over 90 days of VPC flow logs against that account produces a list of observed source IPs. The CIDR collapsing algorithm then computes the tightest covering blocks. A worked example of what that reduction looks like:

Before | After
10.0.0.0/16 on port 443 (65,536 addresses trusted) | 10.0.10.0/27, 10.0.20.10/31 (34 addresses trusted)
10.4.0.0/16 on port 5432 (65,536 addresses trusted) | 10.4.8.0/29 (8 addresses trusted)
10.8.0.0/16 on port 8443 (65,536 addresses trusted) | 10.8.2.0/28, 10.8.16.0/28 (32 addresses trusted)

In this example the trusted address space drops from 196,608 addresses to 74, a reduction in lateral movement surface area of over 99.9 percent. The security group rule count stays well within the default limit of 60 rules after CIDR collapsing. Nothing that was legitimately connecting during the observation window is locked out, because every address observed in the flow logs falls inside one of the replacement blocks.
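The arithmetic in the worked example can be checked with Python's standard ipaddress module. This is only a verification sketch: the CIDR lists are copied from the table above, and the helper name trusted_addresses is made up for illustration rather than taken from the tool.

```python
import ipaddress

# Before/after CIDR sets copied from the worked example in the table.
before = ["10.0.0.0/16", "10.4.0.0/16", "10.8.0.0/16"]
after = [
    "10.0.10.0/27", "10.0.20.10/31",   # port 443
    "10.4.8.0/29",                     # port 5432
    "10.8.2.0/28", "10.8.16.0/28",     # port 8443
]

def trusted_addresses(cidrs):
    """Total addresses covered by a list of non-overlapping CIDR blocks."""
    return sum(ipaddress.ip_network(c).num_addresses for c in cidrs)

before_total = trusted_addresses(before)
after_total = trusted_addresses(after)
reduction_pct = 100 * (1 - after_total / before_total)

print(before_total, after_total, f"{reduction_pct:.2f}%")
# prints: 196608 74 99.96%
```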

That reduction is not achievable manually at any realistic pace. It requires the forensic analysis that flow logs make possible and that sg-tightener automates.

4. The Danger of Incomplete Observation Windows

This is the part of the tool that requires the most operational care, and it is worth addressing directly rather than burying in documentation.

The default analysis window is 90 days. That is long enough to capture most regular application traffic, batch jobs, and operational tooling. It is not long enough to capture everything in every environment, and if you run the tool without thinking about what your 90 day window might be missing, you risk locking out legitimate traffic.

The categories of traffic most likely to be absent from a 90 day window are the ones that matter most in a crisis. DR systems that only connect during failover tests, which may run quarterly or annually. Finance batch processes that run at month end or quarter end and may not have fired in the observation window. Blue green deployments where the old environment was inactive during the analysis period and comes back live after you have deployed. Maintenance hosts that only connect during planned maintenance windows. Vendor access paths that are used infrequently but are critical when needed.

The principle here is that absence of evidence is not evidence of absence. An IP address that did not connect during your observation window is not necessarily an IP address that should be blocked. It may simply be an IP address that did not happen to need access during that period.

sg-tightener handles this in several ways. The --days parameter lets you extend the window to 180 days or longer for environments where you know seasonal or infrequent traffic patterns exist. The diagnose script exists specifically for the failure mode where a legitimate source gets blocked after deployment: it scans REJECT entries in the flow logs over a configurable lookback window, surfaces IPs that are not covered by any current rule, and lets you add them to the approved list and re-apply immediately. The revert mode lets you restore the exact pre-deployment state within minutes if something goes wrong, and the apply mode is engineered so that if any single security group fails to update cleanly, the tool halts immediately and prints the revert command, so partial states cannot persist silently.
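To make the diagnose idea concrete, here is a minimal sketch of the core check it describes: find source IPs on REJECT records that no current rule covers. It is not the diagnose script itself. The function name uncovered_rejects and the sample records are invented for illustration, and it assumes the default flow log format, where srcaddr is the fourth field and action is the thirteenth.

```python
import ipaddress

def uncovered_rejects(flow_lines, allowed_cidrs):
    """Return rejected source IPs that none of the allowed CIDRs cover."""
    allowed = [ipaddress.ip_network(c) for c in allowed_cidrs]
    found = set()
    for line in flow_lines:
        parts = line.split()
        # Default format: srcaddr at index 3, action at index 12.
        if len(parts) < 14 or parts[12] != "REJECT":
            continue
        try:
            src = ipaddress.ip_address(parts[3])
        except ValueError:
            continue  # "-" placeholder or malformed address
        if not any(src in net for net in allowed):
            found.add(str(src))
    return sorted(found)

# Hypothetical sample records in the default flow log format.
sample = [
    "2 123456789012 eni-0abc 10.0.10.5 10.1.0.20 49152 443 6 10 840 1700000000 1700000060 ACCEPT OK",
    "2 123456789012 eni-0abc 10.0.99.7 10.1.0.20 49153 443 6 3 180 1700000000 1700000060 REJECT OK",
    "2 123456789012 eni-0abc 10.0.10.9 10.1.0.20 49154 443 6 3 180 1700000000 1700000060 REJECT OK",
]

missing = uncovered_rejects(sample, ["10.0.10.0/27"])
print(missing)  # ['10.0.99.7']
```

The rejected record from 10.0.10.9 is not reported because its source already falls inside an allowed block, which suggests the rejection came from a port or protocol mismatch rather than a missing CIDR.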

The operational recommendation is to run analyse over the longest window you have available, extend to 180 days for any account where you know DR or seasonal batch traffic exists, and plan deployments for periods when you have on-call coverage and can run the diagnose script immediately if connection failures appear. This is not a fire-and-forget tool. It is a forensic analysis tool with a safety net, and the safety net works best when someone is watching.

5. The CIDR Collapsing Algorithm

AWS security groups have a default quota of 60 inbound rules per group. If your account has been accessed by 200 distinct source IPs over three months, you cannot write 200 /32 rules. You also cannot write a /16 rule that covers all of them, because that reintroduces the permissiveness you are trying to eliminate. The algorithm has to find the middle ground: the smallest set of CIDR blocks that covers all observed addresses without exceeding the rule limit.

The approach sg-tightener uses works in three layers. The first layer computes, for every observed IP, the widest covering prefix where the gap fraction (fraction of addresses in the block that were not observed) stays within a user-specified tolerance. The default tolerance is 30 percent, meaning at most 30 percent of any block’s address space can consist of addresses that were never observed. The algorithm greedily prefers wider blocks over narrower ones when both satisfy the tolerance constraint, so densely-populated subnets collapse aggressively while sparse outliers remain as /32 rules.
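As an illustration of the gap fraction test, the sketch below computes it for two candidate blocks around the same six observed hosts. The function name gap_fraction and the sample addresses are assumptions for this example. The 0.30 threshold is the default tolerance described above.

```python
import ipaddress

def gap_fraction(cidr, observed_ips):
    """Fraction of addresses in the block that were never observed."""
    net = ipaddress.ip_network(cidr)
    seen = {ip for ip in observed_ips if ipaddress.ip_address(ip) in net}
    return (net.num_addresses - len(seen)) / net.num_addresses

# Six hypothetical observed hosts: 10.0.10.0 through 10.0.10.5.
observed = [f"10.0.10.{i}" for i in range(6)]

tolerance = 0.30
for cidr in ("10.0.10.0/29", "10.0.10.0/28"):
    gap = gap_fraction(cidr, observed)
    print(cidr, gap, "accepted" if gap <= tolerance else "rejected")
# 10.0.10.0/29 0.25 accepted   (8 addresses, 2 unobserved)
# 10.0.10.0/28 0.625 rejected  (16 addresses, 10 unobserved)
```

The /29 is the widest block that stays within tolerance here, so the six hosts would collapse into one rule instead of six /32 rules.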

The second layer enforces the rule count budget. If the first pass produces more rules than the security group can hold, the algorithm widens tolerance in 5 percent steps up to 95 percent, recomputing at each step. Each widening is logged with a warning so the operator can see exactly what trade-off was made. If even 95 percent tolerance does not bring the count within budget, the third layer kicks in: a force-fit pass that merges the closest pairs of blocks regardless of gap tolerance, choosing merges that introduce the smallest amount of new untrusted address space first. This guarantees the rule budget is met while remaining as evidence-aligned as possible, with a loud warning recommending a quota increase from AWS Support.
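The cost the force-fit pass minimises can be shown with a small sketch: for two disjoint blocks, find the smallest supernet that contains both and count how many extra addresses it would trust. The helper name merge_cost is hypothetical, and the sketch assumes both inputs are disjoint IPv4 networks.

```python
import ipaddress

def merge_cost(a, b):
    """Return (common supernet, extra addresses trusted) for two disjoint blocks."""
    net_a, net_b = ipaddress.ip_network(a), ipaddress.ip_network(b)
    candidate = net_a
    # Widen one bit at a time until the candidate also contains net_b.
    while not net_b.subnet_of(candidate):
        candidate = candidate.supernet()
    extra = candidate.num_addresses - net_a.num_addresses - net_b.num_addresses
    return candidate, extra

print(merge_cost("10.0.0.0/30", "10.0.0.4/30"))
# (IPv4Network('10.0.0.0/29'), 0)   sibling blocks merge at no cost
print(merge_cost("10.0.0.0/30", "10.0.0.8/30"))
# (IPv4Network('10.0.0.0/28'), 8)   this merge trusts 8 unobserved addresses
```

Given the choice, a pass that prefers the cheapest merge would take the first pair before the second.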

Independently of CIDR collapsing, the algorithm also handles port range merging when a single security group ends up with multiple port ranges across the same CIDR set after replacement rule construction. Adjacent or overlapping port ranges are merged into the smallest spanning range. The plan output explicitly flags any security group where port merging was required so the operator can review whether the broader port range is acceptable.
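A small standalone example of merging adjacent or overlapping port ranges follows. It is an illustrative sketch of the interval-merging idea for a single CIDR and protocol, with a made-up function name and sample ports, not an excerpt from the tool.

```python
def merge_ranges(ranges):
    """Merge adjacent or overlapping (from_port, to_port) pairs."""
    merged = []
    for start, end in sorted(ranges):
        # Adjacent means the new range starts at most one port past the last one.
        if merged and start <= merged[-1][1] + 1:
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged

result = merge_ranges([(443, 443), (8080, 8090), (444, 450), (8085, 8100)])
print(result)  # [(443, 450), (8080, 8100)]
```

Four rules become two, but ports 444 to 450 and 8091 to 8100 are now reachable as part of wider ranges, which is why the plan output flags merged groups for review.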

Critically, the per-group rule budget is computed from the actual current state of each security group, not from the global 60 rule limit. If a group has 25 rules that are not being touched (security group references, public 0.0.0.0/0 rules, and tight CIDRs at /24 or tighter), the budget for replacement rules is 60 minus 25 minus the count of permissive rules being removed. This prevents the failure mode where an apply succeeds in revoking but fails in authorising because the destination group cannot hold the new rules.

6. What the Tool Does and Does Not Touch

sg-tightener only modifies rules whose source CIDR is a fully contained subset of an RFC 1918 private block (10.0.0.0/8, 172.16.0.0/12, or 192.168.0.0/16) and whose prefix length is shorter than a configurable threshold (default /24). The containment check uses strict subset semantics rather than overlap semantics, so CIDRs like 192.0.0.0/4 that overlap with RFC 1918 ranges but are not themselves private are correctly excluded.
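The difference between overlap and subset semantics is easy to see with the standard ipaddress module. The sketch below compares the two checks for a genuinely private /16 and for 192.0.0.0/4. The helper names are chosen for this example only.

```python
import ipaddress

PRIVATE = [
    ipaddress.ip_network(c)
    for c in ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")
]

def overlaps_private(cidr):
    """True if the block shares any address with an RFC 1918 range."""
    net = ipaddress.ip_network(cidr)
    return any(net.overlaps(p) for p in PRIVATE)

def inside_private(cidr):
    """True only if the whole block sits inside one RFC 1918 range."""
    net = ipaddress.ip_network(cidr)
    return any(net.subnet_of(p) for p in PRIVATE)

for cidr in ("10.4.0.0/16", "192.0.0.0/4"):
    print(cidr, overlaps_private(cidr), inside_private(cidr))
# 10.4.0.0/16 True True
# 192.0.0.0/4 True False
```

An overlap check would wrongly treat 192.0.0.0/4 as a private rule to tighten. The subset check leaves it alone.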

Rules with a source of 0.0.0.0/0 are left completely untouched regardless of port or protocol. A load balancer with port 443 open to the world stays exactly as it is. Rules that reference other security groups rather than CIDR blocks are left untouched. Rules already scoped at /24 or tighter are not modified. IPv6 rules are left untouched in this release.

Network ACLs are out of scope for the tightening workflow. The OU risk report scans and flags permissive NACL rules alongside security group rules, giving you visibility across both enforcement layers. NACL tightening is a planned phase two. NACLs are stateless and subnet-scoped, which means the blast radius of a misconfigured change covers every resource in the subnet with no per-resource fallback. The evidence-based approach is equally valid for NACLs but the implementation requires separate care and a lower default gap tolerance given the 20 rule limit NACLs impose.

7. Operating at Organisation Scale

sg-tightener is designed to operate across an AWS Organisation, not just a single account. The OU risk report traverses your entire OU tree, assumes a cross-account role in each active account, scans every region you specify in parallel, and produces a risk-ranked output sorted from the most permissive accounts to the least. That report is the starting point: it tells you where to spend your remediation effort first.

For the tightening workflow itself, the operational model that works best at scale depends on where your VPC flow logs are centralised. If you are running a Control Tower estate with a centralised logging account, you can point the analyse mode at the log groups or S3 buckets in that account and aggregate observed IPs across the entire organisation before building your CIDR list. If flow logs are account-local, you run the analyse and plan workflow per account and treat each account’s approved IP list independently.

The cross-account role assumption in the OU report uses the standard OrganizationAccountAccessRole by default, which is the role AWS Organizations creates in member accounts that are created through the organisation. You can override this with --role-name if your organisation uses a different role or naming convention. The IAM permissions required are read-only for analyse, plan, and report modes, with write access to security group rules needed only for apply and revert. The tool supports the standard AWS partitions including aws-cn and aws-us-gov through standard boto3 partition resolution.
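For readers unfamiliar with the pattern, cross-account access of this kind usually comes down to building a role ARN for the target account and exchanging it for temporary credentials through STS. The sketch below shows that shape under stated assumptions: the function names and the session name are hypothetical, and the tool's own implementation may differ.

```python
def role_arn(account_id, role_name="OrganizationAccountAccessRole", partition="aws"):
    """Build the IAM role ARN for a member account in a given partition."""
    return f"arn:{partition}:iam::{account_id}:role/{role_name}"

def ec2_client_for(account_id, region,
                   role_name="OrganizationAccountAccessRole", partition="aws"):
    """Assume the cross-account role and return an EC2 client for that account."""
    import boto3  # imported lazily so role_arn works without boto3 installed

    sts = boto3.client("sts", region_name=region)
    creds = sts.assume_role(
        RoleArn=role_arn(account_id, role_name, partition),
        RoleSessionName="sg-tightener-report",  # hypothetical session name
    )["Credentials"]
    return boto3.client(
        "ec2",
        region_name=region,
        aws_access_key_id=creds["AccessKeyId"],
        aws_secret_access_key=creds["SecretAccessKey"],
        aws_session_token=creds["SessionToken"],
    )

print(role_arn("123456789012"))
# arn:aws:iam::123456789012:role/OrganizationAccountAccessRole
```

Keeping the ARN construction separate makes the partition handling easy to test without touching AWS.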

For organisations running delegated administrator accounts for security services, the recommended pattern is to run the OU report and per-account analyse jobs from a dedicated security tooling account that has the cross-account role in every member account, rather than running from the management account directly. This keeps the management account credentials out of the operational workflow and aligns with the principle of least privilege at the account level.

8. Installation and Prerequisites

The tool requires Python 3.9 or later and four external packages: boto3 and botocore for AWS API access, and pandas plus openpyxl for the Excel output in the OU risk report. Everything else is standard library. Create a virtual environment in your working directory and install the dependencies with the script below. You will also need an IAM role or profile with the permissions listed in the policy document that follows the install script.

cat > install.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail

python3 -m venv .venv
source .venv/bin/activate
pip install 'boto3>=1.34' 'botocore>=1.34' pandas openpyxl
echo "sg-tightener dependencies installed."
echo "Run 'source .venv/bin/activate' in your shell before invoking the tool."
EOF
chmod +x install.sh
./install.sh

The IAM role used to run sg-tightener needs the following permissions. Analyse, plan, and report modes are entirely read-only. Apply and revert additionally need the two write permissions shown.

cat > sg_tightener_policy.json << 'EOF'
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "ReadOnly",
      "Effect": "Allow",
      "Action": [
        "ec2:DescribeSecurityGroups",
        "ec2:DescribeVpcs",
        "ec2:DescribeFlowLogs",
        "ec2:DescribeNetworkAcls",
        "ec2:DescribeAccountAttributes",
        "servicequotas:GetServiceQuota",
        "logs:DescribeLogGroups",
        "logs:DescribeLogStreams",
        "logs:FilterLogEvents",
        "logs:GetLogEvents",
        "s3:GetObject",
        "s3:ListBucket",
        "organizations:ListChildren",
        "organizations:DescribeAccount",
        "organizations:ListAccounts",
        "sts:AssumeRole"
      ],
      "Resource": "*"
    },
    {
      "Sid": "WriteSecurityGroups",
      "Effect": "Allow",
      "Action": [
        "ec2:AuthorizeSecurityGroupIngress",
        "ec2:RevokeSecurityGroupIngress"
      ],
      "Resource": "*"
    }
  ]
}
EOF
chmod 600 sg_tightener_policy.json

9. Main Tool: sg_tightener.py

This is the complete source for the main tool. It implements all four modes in a single file with no dependencies beyond the packages installed above. Drop it into your working directory and invoke it directly with Python. The module-level docstring doubles as the usage reference.

cat > sg_tightener.py << 'EOF'
#!/usr/bin/env python3
"""
sg-tightener: Replace permissive RFC 1918 security group rules with
evidence-based CIDR blocks derived from VPC flow log analysis.

An extension of CloudToRepo (cloudtorepo.com).

Modes:
    analyse  : Read flow logs, build IP list, write approved_ips.json
    plan     : Take approved IP list, produce changeset diff and plan file
    apply    : Execute a plan file, write backup, modify security groups
    revert   : Restore security groups from a backup file

Usage:
    python sg_tightener.py analyse --region af-south-1 --days 90
    python sg_tightener.py plan    --region af-south-1 --approved-ips approved_ips.json
    python sg_tightener.py apply   --region af-south-1 --plan sg_plan_<timestamp>.json
    python sg_tightener.py revert  --region af-south-1 --backup sg_backup_<timestamp>.json

Prerequisites:
    pip install 'boto3>=1.34' pandas openpyxl
"""

import boto3
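import botocore
import botocore.config       # explicit: do not rely on boto3 importing submodules
import botocore.exceptions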
import json
import os
import sys
import argparse
import logging
import ipaddress
import re
import gzip
from collections import defaultdict
from datetime import datetime, timezone, timedelta
from typing import Optional

# Configure boto3 retry behaviour for the entire process. Adaptive mode handles
# throttling exponentially with jitter and respects RetryAfter hints.
BOTO_CONFIG = botocore.config.Config(
    retries={"max_attempts": 10, "mode": "adaptive"},
    connect_timeout=30,
    read_timeout=60
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
log = logging.getLogger(__name__)

DEFAULT_MAX_SG_RULES   = 60
DEFAULT_GAP_TOLERANCE  = 0.30
DEFAULT_DAYS           = 90
PERMISSIVE_PREFIX_LEN  = 24

PRIVATE_RANGES = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def is_private(ip_str: str) -> bool:
    """Return True if ip_str is an IPv4 address inside RFC 1918."""
    if not ip_str:
        return False
    try:
        addr = ipaddress.ip_address(ip_str)
    except (ValueError, TypeError):
        return False
    if not isinstance(addr, ipaddress.IPv4Address):
        return False
    return any(addr in net for net in PRIVATE_RANGES)


def is_permissive_cidr(cidr_str: str) -> bool:
    """
    Return True if the CIDR is a private block large enough to be a concern.

    Uses strict subnet_of containment rather than overlap, so non-RFC1918
    CIDRs that happen to overlap with private ranges (e.g. 192.0.0.0/4)
    are not flagged.
    """
    if not cidr_str or cidr_str in ("0.0.0.0/0", "::/0"):
        return False
    try:
        net = ipaddress.ip_network(cidr_str, strict=False)
    except ValueError:
        return False
    if not isinstance(net, ipaddress.IPv4Network):
        return False
    if not any(net.subnet_of(p) for p in PRIVATE_RANGES):
        return False
    return net.prefixlen < PERMISSIVE_PREFIX_LEN


def timestamp_str() -> str:
    return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")


def safe_write_json(data, path: str):
    """
    Atomically write JSON: write to .tmp, fsync, rename.
    Prevents partial backup files if the process is killed mid-write.
    """
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(data, f, indent=2, default=str)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)


# ---------------------------------------------------------------------------
# CIDR collapsing algorithm
# ---------------------------------------------------------------------------

def collapse_ips_to_cidrs(
    ip_list: list,
    gap_tolerance: float = DEFAULT_GAP_TOLERANCE,
    max_rules: int = DEFAULT_MAX_SG_RULES,
    min_prefix_len: int = 16,
    force_fit: bool = True
) -> list:
    """
    Collapse a list of observed IPs into the smallest valid CIDR set.

    Algorithm:
      1. For each prefix length P (min_prefix_len .. 32), enumerate candidate
         supernets and accept those whose gap fraction is within tolerance.
      2. Greedy from widest (smallest P) to narrowest: claim each candidate
         for any uncovered IPs it contains. Wider blocks always preferred.
      3. If rule count exceeds max_rules, widen tolerance in 5% steps to 95%.
      4. If still over budget and force_fit=True, perform force-fit merging
         that ignores gap tolerance, with warnings.

    Returns: list of CIDR strings.
    """
    if not ip_list:
        return []

    addrs = sorted(set(
        ipaddress.ip_address(ip) for ip in ip_list if is_private(ip)
    ))
    if not addrs:
        return []

    observed_ints = set(int(a) for a in addrs)
    log.info("Collapsing %d unique observed IPs (tolerance=%.0f%%, max_rules=%d).",
             len(observed_ints), gap_tolerance * 100, max_rules)

    effective_tolerance = gap_tolerance
    cidrs               = _build_cidrs_at_tolerance(
        observed_ints, effective_tolerance, min_prefix_len
    )
    log.info("Initial collapse at %.0f%% tolerance: %d CIDR block(s).",
             effective_tolerance * 100, len(cidrs))

    while len(cidrs) > max_rules and effective_tolerance < 0.95:
        # Capture the pre-widening state so the warning reports it accurately.
        prev_tolerance      = effective_tolerance
        prev_count          = len(cidrs)
        effective_tolerance = min(0.95, effective_tolerance + 0.05)
        cidrs               = _build_cidrs_at_tolerance(
            observed_ints, effective_tolerance, min_prefix_len
        )
        log.warning(
            "Rule count %d exceeded limit %d at %.0f%% tolerance. "
            "Widened to %.0f%%, now have %d blocks.",
            prev_count, max_rules, prev_tolerance * 100,
            effective_tolerance * 100, len(cidrs)
        )

    if effective_tolerance > gap_tolerance:
        log.warning(
            "Final tolerance %.0f%% (requested: %.0f%%). Review the output.",
            effective_tolerance * 100, gap_tolerance * 100
        )

    if len(cidrs) > max_rules and force_fit:
        log.warning(
            "Even at 95%% tolerance, rule count %d exceeds limit %d. "
            "Performing FORCE FIT: merging blocks ignoring gap tolerance. "
            "Review the resulting CIDR list carefully and consider requesting "
            "a security group quota increase from AWS Support.",
            len(cidrs), max_rules
        )
        cidrs = _force_fit(cidrs, max_rules, min_prefix_len)

    if len(cidrs) > max_rules:
        log.error(
            "Final rule count %d still exceeds limit %d. The tool cannot reduce "
            "further without exceeding tolerance. Manual intervention required.",
            len(cidrs), max_rules
        )

    log.info("Final CIDR set: %d block(s).", len(cidrs))
    return [str(c) for c in cidrs]


def _build_cidrs_at_tolerance(
    observed_ints: set,
    tolerance: float,
    min_prefix_len: int
) -> list:
    """
    Find acceptable supernets per prefix length and greedy-claim wider blocks first.
    """
    if not observed_ints:
        return []

    candidates = {}
    for p in range(min_prefix_len, 33):
        size = 1 << (32 - p)
        mask = ~(size - 1) & 0xFFFFFFFF
        buckets = defaultdict(set)
        for ip in observed_ints:
            buckets[ip & mask].add(ip)
        acceptable = {}
        for base, ips in buckets.items():
            gap = (size - len(ips)) / size
            if gap <= tolerance:
                acceptable[base] = ips
        candidates[p] = acceptable

    covered = set()
    chosen  = []
    for p in range(min_prefix_len, 33):
        acceptable = candidates[p]
        for base in sorted(acceptable.keys(), key=lambda b: -len(acceptable[b])):
            ips = acceptable[base]
            if ips - covered:
                chosen.append((p, base))
                covered |= ips

    for ip in observed_ints - covered:
        chosen.append((32, ip))

    return _networks_from_chosen(chosen)


def _force_fit(networks: list, max_rules: int, min_prefix_len: int) -> list:
    """
    Aggressively merge to fit max_rules, ignoring gap tolerance.

    Repeatedly finds the pair of adjacent networks whose smallest common
    supernet introduces the least new untrusted space, and merges them.
    """
    networks = sorted(networks, key=lambda n: int(n.network_address))
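    # Safety bound against infinite loops. Every successful merge removes at
    # least one network, so len(networks) iterations are always sufficient.
    iteration_cap = len(networks)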

    while len(networks) > max_rules and iteration_cap > 0:
        iteration_cap -= 1
        best_idx = None
        best_cost = float("inf")
        best_supernet = None

        for i in range(len(networks) - 1):
            a, b   = networks[i], networks[i + 1]
            common = _smallest_common_supernet(a, b, min_prefix_len)
            if common is None:
                continue
            cost = common.num_addresses - a.num_addresses - b.num_addresses
            if cost < best_cost:
                best_cost     = cost
                best_idx      = i
                best_supernet = common

        if best_idx is None:
            log.warning("Force-fit: no further merges possible at min_prefix_len=%d.",
                        min_prefix_len)
            break

        new_list = []
        absorbed = False
        for n in networks:
            if n.subnet_of(best_supernet):
                if not absorbed:
                    new_list.append(best_supernet)
                    absorbed = True
            else:
                new_list.append(n)
        networks = list(ipaddress.collapse_addresses(new_list))

    return networks


def _smallest_common_supernet(a, b, min_prefix_len: int):
    """
    Find the smallest network containing both a and b. Returns None if that
    would need a prefix shorter (a wider block) than min_prefix_len.
    """
    if a.subnet_of(b):
        return b
    if b.subnet_of(a):
        return a
    p = min(a.prefixlen, b.prefixlen)
    while p >= min_prefix_len:
        candidate = a.supernet(new_prefix=p)
        if b.subnet_of(candidate):
            return candidate
        p -= 1
    return None


def _networks_from_chosen(chosen: list) -> list:
    """Convert (prefix_len, base_int) tuples to deduplicated, collapsed networks."""
    networks = []
    for p, base in chosen:
        networks.append(ipaddress.ip_network(
            f"{ipaddress.IPv4Address(base)}/{p}", strict=False
        ))
    networks.sort(key=lambda n: (n.prefixlen, int(n.network_address)))
    deduped = []
    for n in networks:
        if any(n.subnet_of(other) and n != other for other in deduped):
            continue
        deduped.append(n)
    return list(ipaddress.collapse_addresses(deduped))


# ---------------------------------------------------------------------------
# Port range merging
# ---------------------------------------------------------------------------

def merge_port_ranges(rules: list, target_count: int) -> list:
    """
    Merge adjacent/overlapping port ranges on the same (cidr, protocol) key
    until total rule count fits target_count.
    """
    if len(rules) <= target_count:
        return rules

    log.warning(
        "Rule count %d exceeds target %d after CIDR collapsing. "
        "Merging port ranges.", len(rules), target_count
    )

    grouped = defaultdict(list)
    for rule in rules:
        key = (rule["cidr"], rule["protocol"])
        grouped[key].append((rule["from_port"], rule["to_port"]))

    merged_rules = []
    for (cidr, protocol), port_pairs in grouped.items():
        if protocol in ("-1", "all"):
            merged_rules.append({
                "cidr": cidr, "protocol": protocol,
                "from_port": 0, "to_port": 65535
            })
            continue

        sorted_pairs = sorted(set(port_pairs))
        merged_pairs = []
        cur_start, cur_end = sorted_pairs[0]
        for start, end in sorted_pairs[1:]:
            if start <= cur_end + 1:
                cur_end = max(cur_end, end)
            else:
                merged_pairs.append((cur_start, cur_end))
                cur_start, cur_end = start, end
        merged_pairs.append((cur_start, cur_end))

        for from_port, to_port in merged_pairs:
            merged_rules.append({
                "cidr": cidr, "protocol": protocol,
                "from_port": from_port, "to_port": to_port
            })

    if len(merged_rules) > target_count:
        log.warning(
            "After port merging, count is %d (target %d). "
            "Consider a security group rule quota increase.",
            len(merged_rules), target_count
        )

    return merged_rules


# ---------------------------------------------------------------------------
# Flow log reading
# ---------------------------------------------------------------------------

# Default flow log format v2 field positions
FLOW_LOG_V2_FIELDS = [
    "version", "account-id", "interface-id", "srcaddr", "dstaddr",
    "srcport", "dstport", "protocol", "packets", "bytes",
    "start", "end", "action", "log-status"
]


def parse_flow_log_format(format_string: Optional[str]) -> list:
    """
    Parse the LogFormat field from a flow log descriptor.

    Returns the ordered list of field names. Returns default v2 fields
    if format_string is empty or None.
    """
    if not format_string:
        return list(FLOW_LOG_V2_FIELDS)
    # AWS format uses ${field-name} tokens
    fields = re.findall(r"\$\{([^}]+)\}", format_string)
    return fields if fields else list(FLOW_LOG_V2_FIELDS)
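

# Illustrative sketch only (not part of the original tool): how a custom
# LogFormat string becomes the field_positions lookup that
# extract_ips_from_line expects. The helper name _demo_field_positions and
# any sample format string passed to it are assumptions for demonstration.
def _demo_field_positions(format_string):
    """Map each ${field} token to its 0-indexed column."""
    import re  # local import keeps the sketch runnable on its own
    fields = re.findall(r"\$\{([^}]+)\}", format_string)
    return {name: index for index, name in enumerate(fields)}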


def extract_ips_from_line(line: str, field_positions: dict, action_filter: str) -> Optional[str]:
    """
    Extract source IP from a single flow log line if the action matches.

    field_positions: dict of field_name -> 0-indexed position
    action_filter: 'ACCEPT' or 'REJECT'
    """
    parts = line.split()
    action_pos = field_positions.get("action")
    src_pos    = field_positions.get("srcaddr")
    if action_pos is None or src_pos is None:
        return None
    if len(parts) <= max(action_pos, src_pos):
        return None
    if parts[action_pos] != action_filter:
        return None
    src_ip = parts[src_pos]
    if src_ip and src_ip != "-" and is_private(src_ip):
        return src_ip
    return None
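

# Illustrative sketch only (not part of the original tool): the same filter
# applied to a record in the default v2 layout, where srcaddr is column 3 and
# action is column 12. The helper name is hypothetical, and
# ipaddress.is_private is used here as a stand-in for is_private(), which is
# defined elsewhere in this script and may classify addresses differently.
def _demo_extract_source_ip(line, action_filter="ACCEPT"):
    """Return the private source IP of a default-format record, else None."""
    import ipaddress  # local import keeps the sketch runnable on its own
    parts = line.split()
    if len(parts) <= 12 or parts[12] != action_filter:
        return None
    try:
        address = ipaddress.ip_address(parts[3])
    except ValueError:
        # NODATA / SKIPDATA records carry "-" instead of an address.
        return None
    return parts[3] if address.is_private else None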


def check_flow_logs_enabled(ec2_client, vpc_ids: list, required_days: int) -> dict:
    """
    Verify VPC flow logs are enabled on all VPCs and have been running for
    at least required_days. Returns dict of vpc_id -> flow_log_descriptor.

    Raises RuntimeError with detailed remediation if any VPC fails the check.
    """
    if not vpc_ids:
        raise RuntimeError("No VPCs found in this region. Nothing to analyse.")

    try:
        fl_resp = ec2_client.describe_flow_logs(
            Filter=[{"Name": "resource-id", "Values": vpc_ids}]
        )
    except botocore.exceptions.ClientError as e:
        # Chain the original error so the AWS error code stays in the traceback
        raise RuntimeError(f"Could not list flow logs: {e}") from e

    result = {}
    for fl in fl_resp.get("FlowLogs", []):
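        if fl.get("FlowLogStatus") != "ACTIVE":
            log.warning("Flow log %s is not ACTIVE (status: %s); skipping.",
                        fl.get("FlowLogId"), fl.get("FlowLogStatus"))
            continue
        # A REJECT-only flow log never records ACCEPT traffic, so it cannot
        # supply evidence of legitimate sources.
        if fl.get("TrafficType") == "REJECT":
            log.warning("Flow log %s captures REJECT traffic only; skipping.",
                        fl.get("FlowLogId"))
            continue
        result[fl.get("ResourceId", "")] = fl

    missing = [v for v in vpc_ids if v not in result]
    if missing:
        raise RuntimeError(
            f"No ACTIVE VPC flow log capturing ACCEPT traffic was found for: {missing}.\n"
            f"Please enable VPC flow logs (traffic type ACCEPT or ALL) for these VPCs "
            f"and wait {required_days} days before running sg-tightener analyse.\n"
            f"For environments with DR or seasonal traffic, consider --days 180."
        )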

    for vpc_id, fl in result.items():
        creation = fl.get("CreationTime")
        if creation:
            age_days = (
                datetime.now(timezone.utc) - creation.replace(tzinfo=timezone.utc)
            ).days
            if age_days < required_days:
                raise RuntimeError(
                    f"Flow logs for {vpc_id} have only been enabled for {age_days} day(s). "
                    f"The analysis window requires {required_days} days of logs.\n"
                    f"Wait {required_days - age_days} more day(s) before re-running, "
                    f"or reduce --days to {age_days} to analyse the available window.\n"
                    f"WARNING: A shorter window increases the risk of locking out "
                    f"infrequently-used sources such as DR systems and batch jobs. "
                    f"Run sg_diagnose.py immediately after deployment if you "
                    f"proceed with a reduced window."
                )

    return result


def read_flow_logs_cloudwatch(
    logs_client,
    log_group_name: str,
    field_positions: dict,
    start_time: datetime,
    end_time: datetime
) -> set:
    """
    Read ACCEPT entries from a CloudWatch Logs flow log group.
    Returns a set of private source IP strings.
    """
    source_ips = set()
    start_ms   = int(start_time.timestamp() * 1000)
    end_ms     = int(end_time.timestamp() * 1000)
    log.info("Reading flow logs from CloudWatch group: %s", log_group_name)

    paginator   = logs_client.get_paginator("filter_log_events")
    page_count  = 0
    event_count = 0
    try:
        for page in paginator.paginate(
            logGroupName=log_group_name,
            startTime=start_ms,
            endTime=end_ms
        ):
            page_count += 1
            for event in page.get("events", []):
                event_count += 1
                ip = extract_ips_from_line(
                    event.get("message", ""), field_positions, "ACCEPT"
                )
                if ip:
                    source_ips.add(ip)
            if page_count % 10 == 0:
                log.info("  Processed %d pages, %d events, %d unique IPs so far.",
                         page_count, event_count, len(source_ips))
    except botocore.exceptions.ClientError as e:
        log.error("CloudWatch Logs read failed for %s: %s", log_group_name, e)
        raise

    log.info("Finished CloudWatch read: %d pages, %d events, %d unique private IPs.",
             page_count, event_count, len(source_ips))
    return source_ips


def read_flow_logs_s3(
    s3_client,
    bucket: str,
    prefix: str,
    field_positions: dict,
    start_time: datetime,
    end_time: datetime
) -> set:
    """
    Read ACCEPT entries from S3-backed flow logs, streaming gzipped content
    line-by-line to avoid loading entire archives into memory.
    """
    source_ips = set()
    log.info("Reading flow logs from S3: s3://%s/%s", bucket, prefix)

    paginator = s3_client.get_paginator("list_objects_v2")
    obj_count = 0
    try:
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                key           = obj["Key"]
                # Filter on key path date components if available (lower lag than LastModified)
                key_date = _parse_date_from_s3_key(key)
                check_date = key_date or obj["LastModified"].replace(tzinfo=timezone.utc)
                # Allow files within window OR slightly outside (flow logs are delivered with lag)
                if check_date < start_time - timedelta(days=1):
                    continue
                if check_date > end_time + timedelta(days=1):
                    continue
                obj_count += 1
                try:
                    resp = s3_client.get_object(Bucket=bucket, Key=key)
                    body_stream = resp["Body"]
                    if key.endswith(".gz"):
                        # Stream-decompress to avoid loading the whole file
                        decompressed = gzip.GzipFile(fileobj=body_stream)
                        for raw_line in decompressed:
                            line = raw_line.decode("utf-8", errors="replace")
                            ip = extract_ips_from_line(line, field_positions, "ACCEPT")
                            if ip:
                                source_ips.add(ip)
                    else:
                        for raw_line in body_stream.iter_lines():
                            line = raw_line.decode("utf-8", errors="replace")
                            ip = extract_ips_from_line(line, field_positions, "ACCEPT")
                            if ip:
                                source_ips.add(ip)
                except Exception as e:
                    log.warning("Could not read S3 object %s: %s", key, e)
                if obj_count % 100 == 0:
                    log.info("  Processed %d S3 objects, %d unique IPs so far.",
                             obj_count, len(source_ips))
    except botocore.exceptions.ClientError as e:
        log.error("S3 read failed for bucket %s: %s", bucket, e)
        raise

    log.info("Finished S3 read: %d objects, %d unique private IPs.",
             obj_count, len(source_ips))
    return source_ips


def _parse_date_from_s3_key(key: str) -> Optional[datetime]:
    """
    Extract date from AWS flow log S3 key path: .../<yyyy>/<mm>/<dd>/...
    Returns a UTC datetime at midnight, or None if path doesn't match.
    """
    m = re.search(r"/(\d{4})/(\d{2})/(\d{2})/", key)
    if not m:
        return None
    try:
        return datetime(int(m.group(1)), int(m.group(2)), int(m.group(3)),
                        tzinfo=timezone.utc)
    except ValueError:
        return None


def parse_s3_destination(arn_or_path: str):
    """
    Parse a flow log S3 destination into (bucket, prefix).
    Supports aws, aws-cn, and aws-us-gov partitions.
    """
    # ARN format: arn:<partition>:s3:::bucket-name/prefix
    m = re.match(r"^arn:(aws|aws-cn|aws-us-gov):s3:::([^/]+)/?(.*)$", arn_or_path)
    if m:
        return m.group(2), m.group(3)
    # Path format: s3://bucket/prefix
    m = re.match(r"^s3://([^/]+)/?(.*)$", arn_or_path)
    if m:
        return m.group(1), m.group(2)
    return arn_or_path, ""


# ---------------------------------------------------------------------------
# Security group helpers
# ---------------------------------------------------------------------------

def describe_all_security_groups(ec2_client) -> list:
    """List every security group in the region with pagination."""
    groups    = []
    paginator = ec2_client.get_paginator("describe_security_groups")
    try:
        for page in paginator.paginate():
            groups.extend(page.get("SecurityGroups", []))
    except botocore.exceptions.ClientError as e:
        log.error("Could not list security groups: %s", e)
        raise
    log.info("Found %d security group(s) in region.", len(groups))
    return groups


def find_permissive_rules(sg: dict) -> list:
    """Identify inbound rules whose source CIDR is a large private block."""
    permissive = []
    for perm in sg.get("IpPermissions", []):
        for iprange in perm.get("IpRanges", []):
            cidr = iprange.get("CidrIp", "")
            if is_permissive_cidr(cidr):
                permissive.append({"permission": perm, "cidr": cidr})
    return permissive


def count_non_permissive_rules(sg: dict) -> int:
    """Count rules that are NOT being replaced (untouched in apply)."""
    total = 0
    for perm in sg.get("IpPermissions", []):
        # CIDR sources we keep: 0.0.0.0/0, tight /24+ CIDRs, IPv6
        for iprange in perm.get("IpRanges", []):
            if not is_permissive_cidr(iprange.get("CidrIp", "")):
                total += 1
        total += len(perm.get("Ipv6Ranges", []))
        total += len(perm.get("UserIdGroupPairs", []))
        total += len(perm.get("PrefixListIds", []))
    return total


def get_sg_rule_quota(ec2_client) -> int:
    """
    Get the actual per-security-group inbound rule limit.

    Falls back to DEFAULT_MAX_SG_RULES if the Service Quotas API is unavailable
    or the value can't be retrieved.
    """
    try:
        sq = boto3.client("service-quotas", region_name=ec2_client.meta.region_name,
                          config=BOTO_CONFIG)
        resp = sq.get_service_quota(
            ServiceCode="vpc",
            QuotaCode="L-0EA8095F"  # Inbound or outbound rules per security group
        )
        return int(resp["Quota"]["Value"])
    except Exception as e:
        log.info("Could not retrieve SG rule quota (using default %d): %s",
                 DEFAULT_MAX_SG_RULES, e)
        return DEFAULT_MAX_SG_RULES


def build_replacement_rules(
    permissive_rules: list,
    collapsed_cidrs: list,
    target_count: int
) -> tuple:
    """
    Build replacement rules from permissive rules and collapsed CIDRs.

    Returns (replacement_rules, port_merge_performed: bool).
    """
    if not collapsed_cidrs:
        return [], False

    flat_rules = []
    for entry in permissive_rules:
        perm      = entry["permission"]
        protocol  = perm.get("IpProtocol", "-1")
        from_port = perm.get("FromPort", 0)
        to_port   = perm.get("ToPort",   65535)
        for cidr in collapsed_cidrs:
            flat_rules.append({
                "cidr":      cidr,
                "protocol":  protocol,
                "from_port": from_port,
                "to_port":   to_port
            })

    # Deduplicate (cidr, protocol, from_port, to_port)
    seen = set()
    deduped = []
    for r in flat_rules:
        key = (r["cidr"], r["protocol"], r["from_port"], r["to_port"])
        if key not in seen:
            seen.add(key)
            deduped.append(r)

    pre_merge_count   = len(deduped)
    after_merge       = merge_port_ranges(deduped, target_count)
    port_merge_done   = len(after_merge) < pre_merge_count

    return after_merge, port_merge_done


def backup_security_groups(security_groups: list, backup_path: str):
    """Write a complete JSON snapshot atomically."""
    safe_write_json(security_groups, backup_path)
    log.info("Backup written: %s (%d groups, %d bytes)",
             backup_path, len(security_groups),
             os.path.getsize(backup_path))


# ---------------------------------------------------------------------------
# Plan summary
# ---------------------------------------------------------------------------

def print_plan_summary(plan: dict):
    changes        = plan.get("changes", [])
    sg_count       = len(changes)
    rule_removals  = sum(len(c["rules_to_remove"])   for c in changes)
    rule_additions = sum(len(c["replacement_rules"]) for c in changes)
    vpc_ids        = set(c.get("vpc_id", "") for c in changes if c.get("vpc_id"))
    port_merges    = sum(1 for c in changes if c.get("port_merge_required"))
    over_budget    = sum(1 for c in changes if c.get("budget_exceeded"))

    print()
    print("=" * 68)
    print("  sg-tightener PLAN SUMMARY")
    print("=" * 68)
    print(f"  Region                          : {plan.get('region', 'unknown')}")
    print(f"  Approved IP source              : {plan.get('approved_ips', 'unknown')}")
    print(f"  Gap tolerance                   : {plan.get('gap_tolerance', 0) * 100:.0f}%")
    print(f"  Security groups to modify       : {sg_count}")
    print(f"  VPCs affected                   : {len(vpc_ids)}")
    print(f"  Permissive rules removed        : {rule_removals}")
    print(f"  Replacement rules added         : {rule_additions}")
    if port_merges:
        print(f"  Groups requiring port merge     : {port_merges}  [review]")
    if over_budget:
        print(f"  Groups OVER RULE BUDGET         : {over_budget}  [BLOCKED]")
    print("=" * 68)
    print()

    for change in changes:
        marker = "  [BUDGET BLOCKED] " if change.get("budget_exceeded") else "  "
        port_marker = " [PORT MERGED]" if change.get("port_merge_required") else ""
        print(f"{marker}{change['sg_id']}  ({change.get('sg_name', '')})  "
              f"VPC: {change.get('vpc_id', '')}{port_marker}")
        for r in change["rules_to_remove"]:
            proto = r["permission"].get("IpProtocol", "?")
            fp    = r["permission"].get("FromPort", "all")
            tp    = r["permission"].get("ToPort",   "all")
            print(f"    REMOVE  {r['cidr']:22s}  proto={proto}  ports={fp}-{tp}")
        for r in change["replacement_rules"]:
            print(f"    ADD     {r['cidr']:22s}  proto={r['protocol']}  "
                  f"ports={r['from_port']}-{r['to_port']}")
        if change.get("budget_exceeded"):
            print(f"    !! Cannot apply: existing {change.get('existing_non_permissive', '?')} "
                  f"non-permissive rules + {len(change['replacement_rules'])} new rules "
                  f"would exceed quota.")
        print()


# ---------------------------------------------------------------------------
# Mode: analyse
# ---------------------------------------------------------------------------

def run_analyse(args):
    session    = boto3.Session(region_name=args.region)
    ec2        = session.client("ec2",  config=BOTO_CONFIG)
    logs_cl    = session.client("logs", config=BOTO_CONFIG)
    s3         = session.client("s3",   config=BOTO_CONFIG)

    end_time   = datetime.now(timezone.utc)
    start_time = end_time - timedelta(days=args.days)
    log.info("Analysis window: %s to %s (%d days)",
             start_time.date(), end_time.date(), args.days)

    try:
        vpc_resp = ec2.describe_vpcs()
    except botocore.exceptions.ClientError as e:
        log.error("Could not list VPCs: %s", e)
        sys.exit(1)

    vpcs = [v["VpcId"] for v in vpc_resp.get("Vpcs", [])]
    log.info("Found %d VPC(s): %s", len(vpcs), vpcs)

    try:
        fl_map = check_flow_logs_enabled(ec2, vpcs, args.days)
    except RuntimeError as e:
        log.error(str(e))
        sys.exit(1)

    all_ips = set()
    for vpc_id, fl in fl_map.items():
        log_format       = parse_flow_log_format(fl.get("LogFormat", ""))
        field_positions  = {name: i for i, name in enumerate(log_format)}
        missing_required = [f for f in ("srcaddr", "action") if f not in field_positions]
        if missing_required:
            log.error(
                "Flow log %s in VPC %s does not include required fields %s. "
                "Reconfigure the flow log with the default v2 format or "
                "include srcaddr and action.",
                fl.get("FlowLogId"), vpc_id, missing_required
            )
            sys.exit(1)
        log.info("Reading flow log for VPC %s (format: %s)", vpc_id,
                 ",".join(log_format))

        destination_type = fl.get("LogDestinationType", "cloud-watch-logs")
        try:
            if destination_type == "cloud-watch-logs":
                log_group = fl.get("LogGroupName")
                if not log_group:
                    log.warning("Flow log %s has no LogGroupName, skipping.",
                                fl.get("FlowLogId"))
                    continue
                ips = read_flow_logs_cloudwatch(
                    logs_cl, log_group, field_positions, start_time, end_time
                )
                all_ips.update(ips)
            elif destination_type == "s3":
                bucket, prefix = parse_s3_destination(fl.get("LogDestination", ""))
                ips = read_flow_logs_s3(
                    s3, bucket, prefix, field_positions, start_time, end_time
                )
                all_ips.update(ips)
            else:
                log.warning("Unsupported flow log destination type %s for %s, skipping.",
                            destination_type, vpc_id)
        except Exception as e:
            log.error("Failed to read flow logs for VPC %s: %s", vpc_id, e)
            sys.exit(1)

    sorted_ips = sorted(all_ips, key=lambda x: ipaddress.ip_address(x))
    output     = {
        "generated_at":     datetime.now(timezone.utc).isoformat(),
        "analysis_days":    args.days,
        "region":           args.region,
        "total_unique_ips": len(sorted_ips),
        "source_ips":       sorted_ips
    }

    out_path = args.output or f"approved_ips_{args.region}_{timestamp_str()}.json"
    safe_write_json(output, out_path)

    print(f"\nFound {len(sorted_ips)} unique private source IPs over {args.days} days.")
    print(f"IP list written to: {out_path}")
    print("\nReview the list carefully. Pay particular attention to any IPs that may")
    print("represent DR systems, batch jobs, or seasonal traffic not active during")
    print("this window. When satisfied, run:")
    print(f"  python sg_tightener.py plan --approved-ips {out_path} --region {args.region}")


# ---------------------------------------------------------------------------
# Mode: plan
# ---------------------------------------------------------------------------

def run_plan(args):
    if not os.path.exists(args.approved_ips):
        log.error("Approved IPs file does not exist: %s", args.approved_ips)
        sys.exit(1)

    with open(args.approved_ips) as f:
        try:
            ip_data = json.load(f)
        except json.JSONDecodeError as e:
            log.error("Approved IPs file is not valid JSON: %s", e)
            sys.exit(1)

    source_ips = ip_data.get("source_ips", [])
    if not source_ips:
        log.error("Approved IPs file contains no source_ips. Run analyse first.")
        sys.exit(1)
    log.info("Loaded %d approved source IPs from %s.", len(source_ips), args.approved_ips)

    session    = boto3.Session(region_name=args.region)
    ec2        = session.client("ec2", config=BOTO_CONFIG)

    # Determine actual per-group rule quota
    max_rules  = args.max_rules if args.max_rules else get_sg_rule_quota(ec2)
    log.info("Security group rule budget: %d rules per group.", max_rules)

    all_groups = describe_all_security_groups(ec2)

    changes = []
    for sg in all_groups:
        permissive = find_permissive_rules(sg)
        if not permissive:
            continue

        # Per-group budget: total quota minus the rules we are KEEPING
        existing_non_permissive = count_non_permissive_rules(sg)
        per_group_budget = max(1, max_rules - existing_non_permissive)
        log.info("SG %s (%s): %d permissive rules to replace, %d non-permissive to keep, "
                 "budget for replacement rules: %d",
                 sg["GroupId"], sg.get("GroupName", ""),
                 len(permissive), existing_non_permissive, per_group_budget)

        collapsed_cidrs = collapse_ips_to_cidrs(
            source_ips,
            gap_tolerance=args.gap_tolerance,
            max_rules=per_group_budget
        )

        replacement_rules, port_merge_done = build_replacement_rules(
            permissive, collapsed_cidrs, per_group_budget
        )

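        # Compare against the true remaining quota: per_group_budget is floored
        # at 1, which would otherwise hide a group that is already full.
        budget_exceeded = len(replacement_rules) > (max_rules - existing_non_permissive)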

        changes.append({
            "sg_id":                   sg["GroupId"],
            "sg_name":                 sg.get("GroupName", ""),
            "vpc_id":                  sg.get("VpcId", ""),
            "rules_to_remove":         permissive,
            "replacement_rules":       replacement_rules,
            "port_merge_required":     port_merge_done,
            "existing_non_permissive": existing_non_permissive,
            "per_group_budget":        per_group_budget,
            "budget_exceeded":         budget_exceeded
        })

    plan = {
        "generated_at":  datetime.now(timezone.utc).isoformat(),
        "region":        args.region,
        "approved_ips":  args.approved_ips,
        "gap_tolerance": args.gap_tolerance,
        "max_rules":     max_rules,
        "changes":       changes
    }

    print_plan_summary(plan)

    plan_path = args.output or f"sg_plan_{args.region}_{timestamp_str()}.json"
    safe_write_json(plan, plan_path)
    log.info("Plan written: %s", plan_path)

    blocked = sum(1 for c in changes if c.get("budget_exceeded"))
    if blocked:
        print(f"\nWARNING: {blocked} security group(s) are BUDGET BLOCKED and will be")
        print("skipped during apply. Resolve by removing other rules from those groups,")
        print("increasing the security group rule quota via AWS Support, or raising")
        print("--gap-tolerance to collapse more aggressively.")

    print(f"Plan saved to: {plan_path}")
    print(f"To apply: python sg_tightener.py apply --plan {plan_path} --region {args.region}")


# ---------------------------------------------------------------------------
# Mode: apply
# ---------------------------------------------------------------------------

def run_apply(args):
    if not os.path.exists(args.plan):
        log.error("Plan file does not exist: %s", args.plan)
        sys.exit(1)

    with open(args.plan) as f:
        try:
            plan = json.load(f)
        except json.JSONDecodeError as e:
            log.error("Plan file is not valid JSON: %s", e)
            sys.exit(1)

    changes = plan.get("changes", [])
    if not changes:
        print("Plan contains no changes. Nothing to do.")
        return

    # Filter out budget-blocked groups
    applicable_changes = [c for c in changes if not c.get("budget_exceeded")]
    blocked_changes    = [c for c in changes if c.get("budget_exceeded")]

    if blocked_changes:
        log.warning("%d security group(s) are budget-blocked and will be SKIPPED:",
                    len(blocked_changes))
        for c in blocked_changes:
            log.warning("  %s (%s)", c["sg_id"], c.get("sg_name", ""))

    if not applicable_changes:
        print("All changes are budget-blocked. Nothing to apply.")
        return

    # Show only applicable changes in the summary for the apply step
    summary_plan = dict(plan)
    summary_plan["changes"] = applicable_changes
    print_plan_summary(summary_plan)

    if not args.yes:
        confirm = input(
            f"About to modify {len(applicable_changes)} security group(s). "
            "Type 'apply' to proceed, or anything else to abort: "
        ).strip()
        if confirm != "apply":
            print("Aborted.")
            sys.exit(0)

    session = boto3.Session(region_name=args.region)
    ec2     = session.client("ec2", config=BOTO_CONFIG)

    # ====================================================================
    # CRITICAL: Take a complete backup of EVERY security group before any change
    # ====================================================================
    sg_ids_to_backup = [c["sg_id"] for c in applicable_changes]
    log.info("Taking pre-apply backup of %d security group(s)...", len(sg_ids_to_backup))

    current_groups = []
    backup_failures = []
    # Describe the groups in chunks of 100 IDs so each request stays small
    chunk_size = 100
    for i in range(0, len(sg_ids_to_backup), chunk_size):
        chunk = sg_ids_to_backup[i:i + chunk_size]
        try:
            resp = ec2.describe_security_groups(GroupIds=chunk)
            current_groups.extend(resp.get("SecurityGroups", []))
        except botocore.exceptions.ClientError as e:
            log.error("Bulk describe failed for chunk; retrying individually: %s", e)
            for sg_id in chunk:
                try:
                    r = ec2.describe_security_groups(GroupIds=[sg_id])
                    current_groups.extend(r.get("SecurityGroups", []))
                except Exception as ex:
                    log.error("Could not describe %s for backup: %s", sg_id, ex)
                    backup_failures.append(sg_id)

    if backup_failures:
        log.error(
            "FAILED to backup %d security group(s): %s. "
            "REFUSING to proceed: any change to these groups would be unrecoverable.",
            len(backup_failures), backup_failures
        )
        sys.exit(1)

    if len(current_groups) != len(sg_ids_to_backup):
        log.error(
            "Backup count mismatch: expected %d groups, got %d. "
            "REFUSING to proceed.",
            len(sg_ids_to_backup), len(current_groups)
        )
        sys.exit(1)

    backup_path = f"sg_backup_{args.region}_{timestamp_str()}.json"
    backup_security_groups(current_groups, backup_path)
    print(f"\nBackup saved to: {backup_path}")
    print("To revert at any time:")
    print(f"  python sg_tightener.py revert --backup {backup_path} --region {args.region}\n")

    # Build a lookup so apply can verify current rule state per group
    backup_lookup = {sg["GroupId"]: sg for sg in current_groups}

    modified_count = 0
    for change in applicable_changes:
        sg_id             = change["sg_id"]
        sg_name           = change.get("sg_name", "")
        rules_to_remove   = change["rules_to_remove"]
        replacement_rules = change["replacement_rules"]

        if not rules_to_remove:
            log.info("  %s (%s): no permissive rules to remove, skipping.", sg_id, sg_name)
            continue

        log.info("  Modifying %s (%s): removing %d rules, adding %d rules",
                 sg_id, sg_name, len(rules_to_remove), len(replacement_rules))
        print(f"  Modifying {sg_id} ({sg_name})...")

        # Idempotency: skip rules that no longer exist in the current group state
        backup_sg     = backup_lookup.get(sg_id, {})
        current_perms = backup_sg.get("IpPermissions", [])
        existing_keys = set()
        for perm in current_perms:
            proto = perm.get("IpProtocol")
            fp    = perm.get("FromPort", -1)
            tp    = perm.get("ToPort", -1)
            for ipr in perm.get("IpRanges", []):
                existing_keys.add((proto, fp, tp, ipr.get("CidrIp")))

        revoke_permissions = []
        for entry in rules_to_remove:
            perm  = entry["permission"]
            proto = perm.get("IpProtocol")
            fp    = perm.get("FromPort", -1)
            tp    = perm.get("ToPort", -1)
            key   = (proto, fp, tp, entry["cidr"])
            if key not in existing_keys:
                log.info("    Skipping already-removed rule %s on %s", entry["cidr"], sg_id)
                continue
            new_perm = {
                "IpProtocol": proto,
                "IpRanges":   [{"CidrIp": entry["cidr"]}]
            }
            if proto not in ("-1", "all"):
                if fp is not None:
                    new_perm["FromPort"] = fp
                if tp is not None:
                    new_perm["ToPort"] = tp
            revoke_permissions.append(new_perm)

        if revoke_permissions:
            try:
                ec2.revoke_security_group_ingress(
                    GroupId=sg_id, IpPermissions=revoke_permissions
                )
                log.info("    Revoked %d rule(s) from %s", len(revoke_permissions), sg_id)
            except botocore.exceptions.ClientError as e:
                if e.response["Error"]["Code"] == "InvalidPermission.NotFound":
                    log.warning("    Some rules to revoke not found in %s (already gone?). "
                                "Proceeding.", sg_id)
                else:
                    log.error("    FAILED to revoke rules from %s: %s", sg_id, e)
                    log.error("    HALTING. Backup: %s", backup_path)
                    log.error("    To revert NOW: python sg_tightener.py revert "
                              "--backup %s --region %s --yes",
                              backup_path, args.region)
                    sys.exit(2)

        new_permissions = []
        for rule in replacement_rules:
            perm = {
                "IpProtocol": rule["protocol"],
                "IpRanges":   [{"CidrIp": rule["cidr"],
                                "Description": "sg-tightener managed"}]
            }
            if rule["protocol"] not in ("-1", "all"):
                perm["FromPort"] = rule["from_port"]
                perm["ToPort"]   = rule["to_port"]
            new_permissions.append(perm)

        if new_permissions:
            try:
                ec2.authorize_security_group_ingress(
                    GroupId=sg_id, IpPermissions=new_permissions
                )
                log.info("    Added %d replacement rule(s) to %s",
                         len(new_permissions), sg_id)
                modified_count += 1
            except botocore.exceptions.ClientError as e:
                if e.response["Error"]["Code"] == "InvalidPermission.Duplicate":
                    log.warning("    Some replacement rules already exist on %s. "
                                "Proceeding.", sg_id)
                    modified_count += 1
                else:
                    log.error("    FAILED to authorise rules on %s: %s", sg_id, e)
                    log.error("    PARTIAL APPLY. Backup: %s", backup_path)
                    log.error("    To revert NOW: python sg_tightener.py revert "
                              "--backup %s --region %s --yes",
                              backup_path, args.region)
                    sys.exit(2)
        else:
            # No replacement rules. Revocation happened, leaving no replacement.
            # This is correct ONLY if the operator intended to fully remove access.
            log.warning("    %s: no replacement rules. Group may have no ingress now.",
                        sg_id)

    print(f"\nModified {modified_count} security group(s).")
    print(f"Backup: {backup_path}")
    print("\nMonitor for REJECT entries in your flow logs over the next 24-48 hours.")
    print("If connection failures appear, run sg_diagnose.py to identify missing IPs.")
    if blocked_changes:
        print(f"\nNote: {len(blocked_changes)} group(s) were skipped due to rule budget.")


# ---------------------------------------------------------------------------
# Mode: revert
# ---------------------------------------------------------------------------

def run_revert(args):
    if not os.path.exists(args.backup):
        log.error("Backup file does not exist: %s", args.backup)
        sys.exit(1)

    with open(args.backup) as f:
        try:
            backed_up_groups = json.load(f)
        except json.JSONDecodeError as e:
            log.error("Backup file is not valid JSON: %s", e)
            sys.exit(1)

    if not backed_up_groups:
        log.error("Backup file is empty.")
        sys.exit(1)

    log.info("Loaded backup: %d security group(s) from %s",
             len(backed_up_groups), args.backup)

    print(f"\nThis will restore {len(backed_up_groups)} security group(s) "
          f"from backup: {args.backup}")

    if not args.yes:
        confirm = input("Type 'revert' to proceed, or anything else to abort: ").strip()
        if confirm != "revert":
            print("Aborted.")
            sys.exit(0)

    session = boto3.Session(region_name=args.region)
    ec2     = session.client("ec2", config=BOTO_CONFIG)

    failures = []
    for sg in backed_up_groups:
        sg_id                = sg["GroupId"]
        sg_name              = sg.get("GroupName", "")
        original_permissions = sg.get("IpPermissions", [])

        print(f"  Restoring {sg_id} ({sg_name})...")

        # Get current rules in order to revoke them
        try:
            current = ec2.describe_security_groups(GroupIds=[sg_id])
            if not current.get("SecurityGroups"):
                log.error("    %s: not found in account. Skipping.", sg_id)
                failures.append((sg_id, "not found"))
                continue
            current_perms = current["SecurityGroups"][0].get("IpPermissions", [])
        except botocore.exceptions.ClientError as e:
            log.error("    Could not describe %s: %s", sg_id, e)
            failures.append((sg_id, str(e)))
            continue

        # CRITICAL: First add the original rules back, then revoke the current ones.
        # If we revoke first and authorise fails, the group is empty.
        # Adding first risks duplicate errors but those are handled gracefully.

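        if original_permissions:
            # EC2 rejects a whole authorise batch if any one rule in it already
            # exists, and rules that apply never touched are still on the group.
            # Restore one rule at a time so a duplicate skips only itself.
            restored, restore_error = 0, None
            for perm in original_permissions:
                base = {k: perm[k] for k in ("IpProtocol", "FromPort", "ToPort")
                        if k in perm}
                for field in ("IpRanges", "Ipv6Ranges",
                              "UserIdGroupPairs", "PrefixListIds"):
                    for item in perm.get(field, []):
                        try:
                            ec2.authorize_security_group_ingress(
                                GroupId=sg_id,
                                IpPermissions=[{**base, field: [item]}]
                            )
                            restored += 1
                        except botocore.exceptions.ClientError as e:
                            code = e.response["Error"]["Code"]
                            if code != "InvalidPermission.Duplicate":
                                restore_error = e
            if restore_error is not None:
                log.error("    FAILED to authorise original rules on %s: %s. "
                          "ABORTING revert of this group; current rules NOT removed.",
                          sg_id, restore_error)
                failures.append((sg_id, str(restore_error)))
                continue
            log.info("    Restored %d missing original rule(s) on %s",
                     restored, sg_id)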

        # Now revoke any current rules that are not in the original set
        original_keys = set()
        for perm in original_permissions:
            proto = perm.get("IpProtocol")
            fp    = perm.get("FromPort", -1)
            tp    = perm.get("ToPort", -1)
            for ipr in perm.get("IpRanges", []):
                original_keys.add((proto, fp, tp, "cidr", ipr.get("CidrIp")))
            for ipr in perm.get("Ipv6Ranges", []):
                original_keys.add((proto, fp, tp, "cidr6", ipr.get("CidrIpv6")))
            for ref in perm.get("UserIdGroupPairs", []):
                original_keys.add((proto, fp, tp, "sg", ref.get("GroupId")))
            for pl in perm.get("PrefixListIds", []):
                original_keys.add((proto, fp, tp, "pl", pl.get("PrefixListId")))

        to_revoke = []
        for perm in current_perms:
            proto = perm.get("IpProtocol")
            fp    = perm.get("FromPort", -1)
            tp    = perm.get("ToPort", -1)
            rebuild = {"IpProtocol": proto}
            if proto not in ("-1", "all"):
                if fp is not None:
                    rebuild["FromPort"] = fp
                if tp is not None:
                    rebuild["ToPort"] = tp
            extra_ranges = []
            for ipr in perm.get("IpRanges", []):
                if (proto, fp, tp, "cidr", ipr.get("CidrIp")) not in original_keys:
                    extra_ranges.append({"CidrIp": ipr.get("CidrIp")})
            if extra_ranges:
                copy = dict(rebuild)
                copy["IpRanges"] = extra_ranges
                to_revoke.append(copy)

        if to_revoke:
            try:
                ec2.revoke_security_group_ingress(
                    GroupId=sg_id, IpPermissions=to_revoke
                )
                log.info("    Revoked %d non-original rule(s) from %s", len(to_revoke), sg_id)
            except botocore.exceptions.ClientError as e:
                log.error("    Could not revoke non-original rules from %s: %s. "
                          "Original rules WERE restored but current rules remain.",
                          sg_id, e)
                failures.append((sg_id, str(e)))

    if failures:
        print("\nRevert completed with %d failure(s):" % len(failures))
        for sg_id, reason in failures:
            print(f"  {sg_id}: {reason}")
        print("\nReview the security groups above manually.")
        sys.exit(3)

    print("\nRevert complete. All security groups restored.")


# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------

def parse_args():
    parser = argparse.ArgumentParser(
        description="sg-tightener: Replace permissive security group rules with "
                    "evidence-based CIDR blocks from VPC flow log analysis."
    )
    sub = parser.add_subparsers(dest="mode", required=True)

    p_analyse = sub.add_parser("analyse", help="Read flow logs, build approved IP list")
    p_analyse.add_argument("--region", required=True)
    p_analyse.add_argument("--days", type=int, default=DEFAULT_DAYS,
                           help=f"Analysis window in days (default: {DEFAULT_DAYS}). "
                                f"Use 180+ for environments with DR or seasonal traffic.")
    p_analyse.add_argument("--output", help="Output JSON path (default: auto-named)")

    p_plan = sub.add_parser("plan", help="Produce changeset from approved IP list")
    p_plan.add_argument("--region",        required=True)
    p_plan.add_argument("--approved-ips",  required=True)
    p_plan.add_argument("--gap-tolerance", type=float, default=DEFAULT_GAP_TOLERANCE,
                        help=f"Max fraction of CIDR block that may be unobserved "
                             f"(default: {DEFAULT_GAP_TOLERANCE})")
    p_plan.add_argument("--max-rules", type=int,
                        help="Override per-group rule limit (default: detect from quota)")
    p_plan.add_argument("--output", help="Plan output path (default: auto-named)")

    p_apply = sub.add_parser("apply", help="Execute a plan file")
    p_apply.add_argument("--region", required=True)
    p_apply.add_argument("--plan",   required=True)
    p_apply.add_argument("--yes",    action="store_true")

    p_revert = sub.add_parser("revert", help="Restore from backup")
    p_revert.add_argument("--region", required=True)
    p_revert.add_argument("--backup", required=True)
    p_revert.add_argument("--yes",    action="store_true")

    return parser.parse_args()


def main():
    args     = parse_args()
    dispatch = {
        "analyse": run_analyse,
        "plan":    run_plan,
        "apply":   run_apply,
        "revert":  run_revert,
    }
    try:
        dispatch[args.mode](args)
    except KeyboardInterrupt:
        log.error("Interrupted by user.")
        sys.exit(130)


if __name__ == "__main__":
    main()
EOF
chmod +x sg_tightener.py
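
The revert ordering in the script above (add the backed-up rules first, then revoke whatever is not in the backup) comes down to two set differences over rule keys. The following is a minimal sketch of that idea for IPv4 CIDR rules only. The helper names `cidr_rule_keys` and `plan_revert` are hypothetical and are not part of sg_tightener.py, and nothing here calls AWS.

```python
def cidr_rule_keys(permissions):
    """Flatten IpPermissions into a set of (protocol, from_port, to_port, cidr)."""
    keys = set()
    for perm in permissions:
        proto = perm.get("IpProtocol")
        from_port = perm.get("FromPort", -1)
        to_port = perm.get("ToPort", -1)
        for ip_range in perm.get("IpRanges", []):
            keys.add((proto, from_port, to_port, ip_range.get("CidrIp")))
    return keys


def plan_revert(original_permissions, current_permissions):
    """Return (to_add, to_revoke) as sorted lists of rule keys."""
    original = cidr_rule_keys(original_permissions)
    current = cidr_rule_keys(current_permissions)
    to_add = sorted(original - current)      # in the backup, missing from the group
    to_revoke = sorted(current - original)   # on the group, absent from the backup
    return to_add, to_revoke


# Backup taken before tightening: one broad rule and one narrow rule.
original = [{"IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
             "IpRanges": [{"CidrIp": "10.0.0.0/8"}, {"CidrIp": "10.9.1.0/24"}]}]
# Group after tightening: the broad rule was replaced by a /26.
current = [{"IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
            "IpRanges": [{"CidrIp": "10.9.1.0/24"}, {"CidrIp": "10.20.4.0/26"}]}]

to_add, to_revoke = plan_revert(original, current)
print("add back:", to_add)
print("revoke  :", to_revoke)
```

Applying `to_add` before `to_revoke` means the group is briefly wider than either state but never empty, which is the property the revert mode is protecting.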

10. Diagnose Script: sg_diagnose.py

Run this after a DR event, a failover, or any time a new service goes live and you start seeing connection failures. It reads REJECT entries from your flow logs, identifies private source IPs not covered by any existing security group rule, prints them for review, and merges them into the approved IP list. With --apply it also backs up the affected groups and writes a plan file that you execute with sg_tightener.py apply.
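
The core of the diagnosis can be sketched without any AWS calls: pick out REJECT records whose source is a private address, then keep the sources that no allowed CIDR contains. The sketch below assumes the default flow log field order (version 2 records); the real script resolves field positions from each flow log's configured format instead. The function names and the sample records are hypothetical.

```python
import ipaddress

# Default VPC flow log field order. Custom formats need their own positions.
DEFAULT_FIELDS = ("version account-id interface-id srcaddr dstaddr srcport "
                  "dstport protocol packets bytes start end action log-status").split()
POS = {name: i for i, name in enumerate(DEFAULT_FIELDS)}
RFC1918 = [ipaddress.ip_network(n)
           for n in ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")]


def rejected_private_sources(lines):
    """Return (src_ip, dst_port) for REJECT records with a private IPv4 source."""
    found = []
    for line in lines:
        parts = line.split()
        if len(parts) < len(DEFAULT_FIELDS) or parts[POS["action"]] != "REJECT":
            continue
        try:
            src = ipaddress.ip_address(parts[POS["srcaddr"]])
        except ValueError:
            continue  # records without traffic carry "-" instead of an address
        if src.version == 4 and any(src in net for net in RFC1918):
            found.append((str(src), parts[POS["dstport"]]))
    return found


def uncovered(sources, allowed_cidrs):
    """Return the unique source IPs that no allowed CIDR contains."""
    nets = [ipaddress.ip_network(c, strict=False) for c in allowed_cidrs]
    missing = {ip for ip, _ in sources
               if not any(ipaddress.ip_address(ip) in n for n in nets)}
    return sorted(missing, key=ipaddress.ip_address)


sample = [
    "2 123456789012 eni-0a1b2c3d 10.20.4.17 10.0.1.5 49152 443 6 1 40 1700000000 1700000060 REJECT OK",
    "2 123456789012 eni-0a1b2c3d 10.20.4.18 10.0.1.5 49153 443 6 9 900 1700000000 1700000060 ACCEPT OK",
    "2 123456789012 eni-0a1b2c3d 203.0.113.9 10.0.1.5 40000 22 6 1 40 1700000000 1700000060 REJECT OK",
    "2 123456789012 eni-0a1b2c3d 10.9.1.5 10.0.1.5 50000 22 6 1 40 1700000000 1700000060 REJECT OK",
]
rejected = rejected_private_sources(sample)
new_ips = uncovered(rejected, ["10.9.1.0/24"])
print(rejected)
print(new_ips)
```

Note that 10.9.1.5 is rejected on port 22 even though a CIDR covers it. A coverage check on source address alone cannot tell you whether the port or protocol matched, so treat the uncovered list as a starting point for review.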

cat > sg_diagnose.py << 'EOF'
#!/usr/bin/env python3
"""
sg-diagnose: Surface rejected private source IPs from VPC flow logs and
add them to the approved IP list, then optionally write a plan to re-apply rules.

Usage:
    python sg_diagnose.py --region af-south-1 --hours 4
    python sg_diagnose.py --region af-south-1 --hours 4 --apply
"""

import boto3
import json
import os
import sys
import argparse
import logging
import ipaddress
from datetime import datetime, timezone, timedelta

from sg_tightener import (
    is_private,
    collapse_ips_to_cidrs,
    describe_all_security_groups,
    find_permissive_rules,
    count_non_permissive_rules,
    get_sg_rule_quota,
    build_replacement_rules,
    backup_security_groups,
    timestamp_str,
    print_plan_summary,
    safe_write_json,
    parse_flow_log_format,
    extract_ips_from_line,
    BOTO_CONFIG,
    DEFAULT_GAP_TOLERANCE,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
log = logging.getLogger(__name__)


def read_rejected_ips(logs_client, log_group_name, field_positions,
                      start_time, end_time):
    """Read REJECT entries and return list of (src_ip, dst_port) tuples."""
    rejected = []
    start_ms = int(start_time.timestamp() * 1000)
    end_ms   = int(end_time.timestamp() * 1000)
    log.info("Scanning REJECT entries in: %s", log_group_name)

    paginator = logs_client.get_paginator("filter_log_events")
    page_count = 0
    for page in paginator.paginate(
        logGroupName=log_group_name,
        startTime=start_ms,
        endTime=end_ms
    ):
        page_count += 1
        for event in page.get("events", []):
            message = event.get("message", "")
            ip = extract_ips_from_line(message, field_positions, "REJECT")
            if ip:
                parts = message.split()
                dst_port_pos = field_positions.get("dstport")
                dst_port = parts[dst_port_pos] if dst_port_pos is not None and \
                           dst_port_pos < len(parts) else "?"
                rejected.append((ip, dst_port))
        if page_count % 10 == 0:
            log.info("  Processed %d pages, %d REJECT entries so far.",
                     page_count, len(rejected))

    log.info("Total REJECT entries from private IPs: %d", len(rejected))
    return rejected


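def ip_covered_by_existing_rules(ip, security_groups):
    """Check whether ip falls inside any inbound IPv4 CIDR rule in security_groups.

    This is a coarse check: it ignores port, protocol, and which group actually
    guards the destination, so a covered IP can still be rejected.
    """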
    try:
        addr = ipaddress.ip_address(ip)
    except ValueError:
        return False
    for sg in security_groups:
        for perm in sg.get("IpPermissions", []):
            for iprange in perm.get("IpRanges", []):
                cidr = iprange.get("CidrIp", "")
                if not cidr:
                    continue
                try:
                    if addr in ipaddress.ip_network(cidr, strict=False):
                        return True
                except ValueError:
                    continue
    return False


def parse_args():
    parser = argparse.ArgumentParser(
        description="sg-diagnose: Find rejected private IPs in flow logs."
    )
    parser.add_argument("--region", required=True)
    parser.add_argument("--hours", type=int, default=4,
                        help="Lookback window in hours (default: 4)")
    parser.add_argument("--approved-ips",
                        help="Existing approved_ips JSON to merge into")
    parser.add_argument("--gap-tolerance", type=float, default=DEFAULT_GAP_TOLERANCE)
    parser.add_argument("--apply", action="store_true",
                        help="After review, back up the groups and write a plan "
                             "file for sg_tightener.py apply")
    parser.add_argument("--yes", action="store_true")
    return parser.parse_args()


def main():
    args = parse_args()

    if args.approved_ips and not os.path.exists(args.approved_ips):
        log.error("Approved IPs file does not exist: %s", args.approved_ips)
        sys.exit(1)

    session = boto3.Session(region_name=args.region)
    ec2     = session.client("ec2",  config=BOTO_CONFIG)
    logs_cl = session.client("logs", config=BOTO_CONFIG)

    end_time   = datetime.now(timezone.utc)
    start_time = end_time - timedelta(hours=args.hours)
    log.info("Scanning REJECT entries from %s to %s (%d hours)",
             start_time.strftime("%Y-%m-%d %H:%M"),
             end_time.strftime("%Y-%m-%d %H:%M"), args.hours)

    try:
        vpcs = [v["VpcId"] for v in ec2.describe_vpcs().get("Vpcs", [])]
    except Exception as e:
        log.error("Could not list VPCs: %s", e)
        sys.exit(1)

    try:
        fl_map = ec2.describe_flow_logs(
            Filter=[{"Name": "resource-id", "Values": vpcs}]
        ).get("FlowLogs", [])
    except Exception as e:
        log.error("Could not list flow logs: %s", e)
        sys.exit(1)

    if not fl_map:
        log.error("No VPC flow logs found. Enable flow logs first.")
        sys.exit(1)

    all_rejected = []
    for fl in fl_map:
        if fl.get("FlowLogStatus") != "ACTIVE":
            continue
        log_format      = parse_flow_log_format(fl.get("LogFormat", ""))
        field_positions = {n: i for i, n in enumerate(log_format)}
        log_group       = fl.get("LogGroupName")
        if not log_group:
            log.info("Skipping non-CloudWatch flow log %s.", fl.get("FlowLogId"))
            continue
        try:
            all_rejected.extend(read_rejected_ips(
                logs_cl, log_group, field_positions, start_time, end_time
            ))
        except Exception as e:
            log.error("Failed reading %s: %s", log_group, e)

    if not all_rejected:
        print(f"\nNo REJECT entries found for private IPs in the last {args.hours} hour(s).")
        print("If you expected failures, verify flow logs are publishing and the")
        print("lookback window is wide enough.")
        return

    all_groups = describe_all_security_groups(ec2)
    unique_new = sorted(
        set(
            ip for ip, _ in all_rejected
            if not ip_covered_by_existing_rules(ip, all_groups)
        ),
        key=lambda x: ipaddress.ip_address(x)
    )

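    print(f"\nFound {len(all_rejected)} REJECT entries from private IPs.")
    if not unique_new:
        # Security groups are unordered allow-lists, so a covered source that is
        # still rejected points to a port or protocol mismatch, a rule on a
        # different group, or a NACL.
        print("\nAll rejected IPs are already covered by a CIDR rule. Check the "
              "port and protocol on those rules, and any NACLs in the path.")
        return

    print(f"{len(unique_new)} IP(s) are NOT covered by any existing rule:\n")
    for ip in unique_new:
        ports = sorted(set(p for i, p in all_rejected if i == ip))
        print(f"  {ip:20s}  destination ports: {', '.join(ports)}")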

    existing_ips = []
    if args.approved_ips:
        try:
            with open(args.approved_ips) as f:
                existing_ips = json.load(f).get("source_ips", [])
        except (json.JSONDecodeError, FileNotFoundError) as e:
            log.warning("Could not read approved IPs file (%s). Starting fresh.", e)

    merged_ips = sorted(
        set(existing_ips) | set(unique_new),
        key=lambda x: ipaddress.ip_address(x)
    )

    out_path = args.approved_ips or f"approved_ips_{args.region}_{timestamp_str()}.json"
    safe_write_json({
        "generated_at":       datetime.now(timezone.utc).isoformat(),
        "region":             args.region,
        "total_unique_ips":   len(merged_ips),
        "source_ips":         merged_ips,
        "diagnose_additions": unique_new
    }, out_path)

    print(f"\nUpdated approved IP list ({len(merged_ips)} total) written to: {out_path}")

    if not args.apply:
        print("\nTo apply, run plan and then apply with the updated file:")
        print(f"  python sg_tightener.py plan --region {args.region} "
              f"--approved-ips {out_path}")
        return

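    # Build a plan using the same logic as the main tool's plan mode.
    max_rules = get_sg_rule_quota(ec2)

    # Evaluate each group once and keep only those with permissive rules.
    groups_to_modify = []
    for sg in all_groups:
        permissive = find_permissive_rules(sg)
        if permissive:
            groups_to_modify.append((sg, permissive))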

    if not groups_to_modify:
        print("\nNo permissive rules found to update.")
        return

    changes = []
    for sg, permissive in groups_to_modify:
        existing_non_permissive = count_non_permissive_rules(sg)
        budget                  = max(1, max_rules - existing_non_permissive)
        collapsed_cidrs         = collapse_ips_to_cidrs(
            merged_ips, gap_tolerance=args.gap_tolerance, max_rules=budget
        )
        replacement_rules, port_merge_done = build_replacement_rules(
            permissive, collapsed_cidrs, budget
        )
        changes.append({
            "sg_id":                   sg["GroupId"],
            "sg_name":                 sg.get("GroupName", ""),
            "vpc_id":                  sg.get("VpcId", ""),
            "rules_to_remove":         permissive,
            "replacement_rules":       replacement_rules,
            "port_merge_required":     port_merge_done,
            "existing_non_permissive": existing_non_permissive,
            "per_group_budget":        budget,
            "budget_exceeded":         len(replacement_rules) > budget
        })

    plan = {
        "region":        args.region,
        "approved_ips":  out_path,
        "gap_tolerance": args.gap_tolerance,
        "max_rules":     max_rules,
        "changes":       changes
    }
    print_plan_summary(plan)

    backup_path = f"sg_backup_diagnose_{args.region}_{timestamp_str()}.json"
    backup_security_groups([sg for sg, _ in groups_to_modify], backup_path)
    print(f"Backup saved to: {backup_path}")

    if not args.yes:
        confirm = input("Type 'apply' to proceed: ").strip()
        if confirm != "apply":
            print("Aborted. Updated approved_ips file has been saved.")
            return

    # No rules are changed here. The plan is written to disk so that the apply
    # mode of sg_tightener.py performs the actual change.
    plan_path = f"sg_plan_diagnose_{args.region}_{timestamp_str()}.json"
    safe_write_json(plan, plan_path)
    print(f"\nPlan saved to: {plan_path}")
    print(f"Run: python sg_tightener.py apply --plan {plan_path} --region {args.region} --yes")


if __name__ == "__main__":
    main()
EOF
chmod +x sg_diagnose.py
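
One detail in the merge step is easy to miss: the approved list is sorted by numeric address value, not as text. A plain string sort places 10.0.10.1 before 10.0.2.1, which makes a long list hard to review. The snippet below is a small sketch of the same merge; `merge_ip_lists` is a hypothetical name used only for illustration.

```python
import ipaddress


def merge_ip_lists(existing, additions):
    """Union two IP lists, drop duplicates, and sort by numeric address."""
    return sorted(set(existing) | set(additions), key=ipaddress.ip_address)


existing = ["10.0.10.1", "10.0.2.1"]
additions = ["10.0.2.1", "10.0.9.200"]

merged = merge_ip_lists(existing, additions)
text_sorted = sorted(set(existing) | set(additions))

print("numeric:", merged)
print("text   :", text_sorted)
```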

11. OU Risk Report: sg_ou_report.py

This script scans an entire AWS Organisation or a specified account list and produces a risk-ranked report. NACLs are included in the report, clearly labelled, even though the tightening workflow only operates on security groups. The report exits with code 1 if any CRITICAL findings are present, making it usable as a pipeline gate.
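
To make the ranking concrete, the sketch below mirrors the gap-based severity rule used in the script that follows (threshold prefix length minus the rule's prefix length) and shows one way a pipeline gate could turn findings into an exit code. The `gate_exit_code` helper is hypothetical and only illustrates the exit-on-CRITICAL behaviour described above.

```python
import ipaddress


def classify_severity(prefix_len, threshold=24):
    """Rank a rule by how many bits wider it is than the threshold prefix."""
    gap = threshold - prefix_len
    if gap >= 16:
        return "CRITICAL"
    if gap >= 8:
        return "HIGH"
    if gap >= 4:
        return "MEDIUM"
    return "LOW"


def gate_exit_code(cidrs, threshold=24):
    """Return 1 if any CIDR wider than the threshold ranks CRITICAL, else 0."""
    severities = []
    for cidr in cidrs:
        net = ipaddress.ip_network(cidr, strict=False)
        if net.prefixlen < threshold:  # rules at or below the threshold width pass
            severities.append(classify_severity(net.prefixlen, threshold))
    return 1 if "CRITICAL" in severities else 0


for prefix in (8, 16, 20, 23):
    print(f"/{prefix}: {classify_severity(prefix)}")

print(gate_exit_code(["10.0.0.0/8", "10.1.2.0/24"]))
print(gate_exit_code(["10.1.0.0/16", "192.168.0.0/20"]))
```

With the default threshold of /24, only rules of /8 or wider reach CRITICAL, so a /16 datacenter rule is reported as HIGH and does not fail the gate on its own.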

cat > sg_ou_report.py << 'EOF'
#!/usr/bin/env python3
"""
sg-ou-report: Risk-rank AWS accounts by security group and NACL permissiveness.
"""

import boto3
import botocore.config
import botocore.exceptions
import csv
import sys
import argparse
import logging
import ipaddress
from datetime import datetime, timezone
from dataclasses import dataclass, asdict
from typing import Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

BOTO_CONFIG = botocore.config.Config(
    retries={"max_attempts": 10, "mode": "adaptive"}
)

try:
    import pandas as pd
    PANDAS_AVAILABLE = True
except ImportError:
    PANDAS_AVAILABLE = False

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
log = logging.getLogger(__name__)

PRIVATE_RANGES = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
]

DEFAULT_MAX_PREFIX_LEN = 24
DEFAULT_ROLE_NAME      = "OrganizationAccountAccessRole"


def classify_severity(prefix_len: int, threshold: int) -> str:
    gap = threshold - prefix_len
    if gap >= 16:
        return "CRITICAL"
    elif gap >= 8:
        return "HIGH"
    elif gap >= 4:
        return "MEDIUM"
    return "LOW"


def is_private_subset(net) -> bool:
    """True only if the network is fully contained within an RFC 1918 range."""
    return any(net.subnet_of(p) for p in PRIVATE_RANGES)


@dataclass
class Finding:
    account_id:     str
    account_name:   str
    region:         str
    resource_type:  str
    resource_id:    str
    resource_name:  str
    vpc_id:         str
    rule_number:    str
    rule_protocol:  str
    rule_from_port: str
    rule_to_port:   str
    cidr:           str
    prefix_len:     int
    severity:       str
    recommendation: str


def list_accounts_in_ou(ou_id: str) -> list:
    org      = boto3.client("organizations", config=BOTO_CONFIG)
    accounts = []

    def recurse(parent_id):
        paginator = org.get_paginator("list_children")
        try:
            for page in paginator.paginate(ParentId=parent_id, ChildType="ACCOUNT"):
                for child in page["Children"]:
                    try:
                        acc = org.describe_account(AccountId=child["Id"])["Account"]
                        if acc["Status"] == "ACTIVE":
                            accounts.append({"id": acc["Id"], "name": acc["Name"]})
                    except botocore.exceptions.ClientError as e:
                        log.warning("Could not describe account %s: %s",
                                    child["Id"], e)
            for page in paginator.paginate(ParentId=parent_id,
                                            ChildType="ORGANIZATIONAL_UNIT"):
                for child in page["Children"]:
                    recurse(child["Id"])
        except botocore.exceptions.ClientError as e:
            log.error("Could not list children of %s: %s", parent_id, e)

    recurse(ou_id)
    return accounts


def get_session(account_id: str, role_name: Optional[str]) -> boto3.Session:
    if not role_name:
        return boto3.Session()
    sts      = boto3.client("sts", config=BOTO_CONFIG)
    role_arn = f"arn:aws:iam::{account_id}:role/{role_name}"
    creds    = sts.assume_role(
        RoleArn=role_arn, RoleSessionName="SGPermissivenessAudit"
    )["Credentials"]
    return boto3.Session(
        aws_access_key_id=creds["AccessKeyId"],
        aws_secret_access_key=creds["SecretAccessKey"],
        aws_session_token=creds["SessionToken"]
    )


def scan_security_groups(ec2, account_id, account_name, region, max_prefix_len):
    findings  = []
    paginator = ec2.get_paginator("describe_security_groups")
    try:
        for page in paginator.paginate():
            for sg in page.get("SecurityGroups", []):
                for perm in sg.get("IpPermissions", []):
                    protocol  = perm.get("IpProtocol", "-1")
                    from_port = str(perm.get("FromPort", "all"))
                    to_port   = str(perm.get("ToPort", "all"))
                    for iprange in perm.get("IpRanges", []):
                        cidr = iprange.get("CidrIp", "")
                        if not cidr or cidr == "0.0.0.0/0":
                            continue
                        try:
                            net = ipaddress.ip_network(cidr, strict=False)
                        except ValueError:
                            continue
                        if not is_private_subset(net):
                            continue
                        if net.prefixlen >= max_prefix_len:
                            continue
                        findings.append(Finding(
                            account_id=account_id, account_name=account_name,
                            region=region, resource_type="SecurityGroup",
                            resource_id=sg["GroupId"],
                            resource_name=sg.get("GroupName", ""),
                            vpc_id=sg.get("VpcId", ""),
                            rule_number="",
                            rule_protocol=protocol,
                            rule_from_port=from_port, rule_to_port=to_port,
                            cidr=cidr, prefix_len=net.prefixlen,
                            severity=classify_severity(net.prefixlen, max_prefix_len),
                            recommendation=(
                                f"Replace /{net.prefixlen} with tightest covering CIDR "
                                f"from flow log analysis. Use sg-tightener."
                            )
                        ))
    except botocore.exceptions.ClientError as e:
        log.error("describe_security_groups failed for %s/%s: %s",
                  account_id, region, e)
    return findings


def scan_nacls(ec2, account_id, account_name, region, max_prefix_len):
    findings  = []
    paginator = ec2.get_paginator("describe_network_acls")
    try:
        for page in paginator.paginate():
            for nacl in page.get("NetworkAcls", []):
                nacl_id   = nacl["NetworkAclId"]
                nacl_name = next(
                    (t["Value"] for t in nacl.get("Tags", [])
                     if t.get("Key") == "Name"),
                    nacl_id
                )
                # Sort entries by RuleNumber to surface the effective order
                entries = sorted(
                    nacl.get("Entries", []),
                    key=lambda e: e.get("RuleNumber", 99999)
                )
                for entry in entries:
                    if entry.get("Egress", True):
                        continue
                    if entry.get("RuleAction", "") != "allow":
                        continue
                    cidr = entry.get("CidrBlock", "")
                    if not cidr or cidr == "0.0.0.0/0":
                        continue
                    try:
                        net = ipaddress.ip_network(cidr, strict=False)
                    except ValueError:
                        continue
                    if not is_private_subset(net):
                        continue
                    if net.prefixlen >= max_prefix_len:
                        continue
                    port_range = entry.get("PortRange", {})
                    findings.append(Finding(
                        account_id=account_id, account_name=account_name,
                        region=region, resource_type="NACL",
                        resource_id=nacl_id, resource_name=nacl_name,
                        vpc_id=nacl.get("VpcId", ""),
                        rule_number=str(entry.get("RuleNumber", "?")),
                        rule_protocol=str(entry.get("Protocol", "-1")),
                        rule_from_port=str(port_range.get("From", "all")),
                        rule_to_port=str(port_range.get("To", "all")),
                        cidr=cidr, prefix_len=net.prefixlen,
                        severity=classify_severity(net.prefixlen, max_prefix_len),
                        recommendation=(
                            f"NACL rule {entry.get('RuleNumber')} permits /{net.prefixlen}. "
                            f"Review manually with rule ordering in mind. "
                            f"NACL tightening is not automated."
                        )
                    ))
    except botocore.exceptions.ClientError as e:
        log.error("describe_network_acls failed for %s/%s: %s",
                  account_id, region, e)
    return findings


def scan_account(account, role_name, regions, max_prefix_len):
    account_id   = account["id"]
    account_name = account["name"]
    log.info("Scanning account %s (%s)", account_id, account_name)

    try:
        session = get_session(account_id, role_name)
    except botocore.exceptions.ClientError as e:
        log.error("Cannot assume role in %s: %s", account_id, e)
        return []
    except Exception as e:
        log.error("Unexpected error assuming role in %s: %s", account_id, e)
        return []

    findings = []
    for region in regions:
        try:
            ec2 = session.client("ec2", region_name=region, config=BOTO_CONFIG)
            findings.extend(scan_security_groups(ec2, account_id, account_name,
                                                  region, max_prefix_len))
            findings.extend(scan_nacls(ec2, account_id, account_name,
                                        region, max_prefix_len))
        except Exception as e:
            log.error("Error scanning %s/%s: %s", account_id, region, e)

    log.info("Account %s: %d finding(s).", account_id, len(findings))
    return findings


def write_csv(findings, path):
    if not findings:
        return
    fieldnames = list(asdict(findings[0]).keys())
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for finding in findings:
            writer.writerow(asdict(finding))
    log.info("CSV written: %s", path)


def write_excel(findings, path):
    if not PANDAS_AVAILABLE or not findings:
        return
    from openpyxl.styles import PatternFill

    df = pd.DataFrame([asdict(f) for f in findings])
    df["_sort"] = df["severity"].map(
        {"CRITICAL": 0, "HIGH": 1, "MEDIUM": 2, "LOW": 3}
    ).fillna(99)
    df = df.sort_values(
        ["_sort", "account_id", "resource_type", "prefix_len"]
    ).drop(columns=["_sort"])

    colour_map = {
        "CRITICAL": "FF4444", "HIGH": "FF8800",
        "MEDIUM": "FFD700",   "LOW": "90EE90"
    }

    with pd.ExcelWriter(path, engine="openpyxl") as writer:
        df.to_excel(writer, index=False, sheet_name="Permissive Rules")
        ws      = writer.sheets["Permissive Rules"]
        sev_col = list(df.columns).index("severity") + 1
        for row_idx, row in enumerate(df.itertuples(index=False), start=2):
            colour = colour_map.get(row.severity, "FFFFFF")
            ws.cell(row=row_idx, column=sev_col).fill = PatternFill(
                start_color=colour, end_color=colour, fill_type="solid"
            )

        acct_summary = (
            df.groupby(["account_id", "account_name"])
              .agg(
                  total_findings=("resource_id", "count"),
                  sg_findings=("resource_type",
                               lambda x: (x == "SecurityGroup").sum()),
                  nacl_findings=("resource_type",
                                  lambda x: (x == "NACL").sum()),
                  critical=("severity", lambda x: (x == "CRITICAL").sum()),
                  high=("severity",     lambda x: (x == "HIGH").sum()),
                  medium=("severity",   lambda x: (x == "MEDIUM").sum()),
                  low=("severity",      lambda x: (x == "LOW").sum()),
                  most_permissive_prefix=("prefix_len", "min")
              )
              .reset_index()
              .sort_values("total_findings", ascending=False)
        )
        acct_summary.to_excel(writer, index=False, sheet_name="Account Risk Ranking")

    log.info("Excel written: %s", path)


def print_summary(findings, max_prefix_len):
    if not findings:
        print("\nNo permissive rules found.")
        return

    from collections import defaultdict
    by_account = defaultdict(list)
    for f in findings:
        by_account[(f.account_id, f.account_name)].append(f)
    sorted_accounts = sorted(by_account.items(),
                              key=lambda x: len(x[1]), reverse=True)

    print("\n" + "=" * 80)
    print(f"  SG + NACL PERMISSIVENESS REPORT  (threshold: /{max_prefix_len})")
    print("=" * 80)
    print(f"  {'ACCOUNT ID':15s}  {'ACCOUNT NAME':28s}  "
          f"{'SG':5s}  {'NACL':5s}  {'TOTAL':6s}  WORST")
    print(f"  {'-'*15}  {'-'*28}  {'-'*5}  {'-'*5}  {'-'*6}  {'-'*14}")

    for (account_id, account_name), account_findings in sorted_accounts:
        sg_count   = sum(1 for f in account_findings
                          if f.resource_type == "SecurityGroup")
        nacl_count = sum(1 for f in account_findings if f.resource_type == "NACL")
        worst_pfx  = min(f.prefix_len for f in account_findings)
        worst_sev  = max(account_findings,
                          key=lambda x: {"CRITICAL": 4, "HIGH": 3,
                                          "MEDIUM": 2, "LOW": 1}.get(x.severity, 0)
                          ).severity
        print(f"  {account_id:15s}  {account_name[:28]:28s}  "
              f"{sg_count:5d}  {nacl_count:5d}  {len(account_findings):6d}  "
              f"/{worst_pfx} ({worst_sev})")

    sevs       = {"CRITICAL": 0, "HIGH": 0, "MEDIUM": 0, "LOW": 0}
    sg_total   = 0
    nacl_total = 0
    for f in findings:
        sevs[f.severity] += 1
        if f.resource_type == "SecurityGroup":
            sg_total += 1
        else:
            nacl_total += 1

    print()
    print(f"  Accounts with findings        : {len(by_account)}")
    print(f"  Security group violations     : {sg_total}")
    print(f"  NACL violations (report only) : {nacl_total}")
    print(f"  Total violations              : {len(findings)}")
    for sev in ["CRITICAL", "HIGH", "MEDIUM", "LOW"]:
        print(f"    {sev:10s}: {sevs[sev]:4d}")
    print("=" * 80)


def parse_args():
    parser = argparse.ArgumentParser(
        description="sg-ou-report: Risk-rank AWS accounts by permissiveness."
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--ou-id",    help="AWS Organizations OU ID")
    group.add_argument("--accounts", nargs="+",
                       help="One or more AWS account IDs to scan instead of an OU")
    parser.add_argument("--role-name",      default=DEFAULT_ROLE_NAME)
    parser.add_argument("--regions",        nargs="+",
                        default=["af-south-1", "eu-west-1", "us-east-1"])
    parser.add_argument("--max-prefix-len", type=int, default=DEFAULT_MAX_PREFIX_LEN)
    parser.add_argument("--workers",        type=int, default=5)
    parser.add_argument("--output-prefix",  default="sg_permissiveness_report")
    return parser.parse_args()


def main():
    args      = parse_args()
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

    if args.ou_id:
        log.info("Discovering accounts in OU: %s", args.ou_id)
        accounts = list_accounts_in_ou(args.ou_id)
        log.info("Found %d active account(s).", len(accounts))
    else:
        accounts = [{"id": a, "name": a} for a in args.accounts]

    if not accounts:
        log.error("No accounts to scan.")
        sys.exit(1)

    all_findings = []
    with ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(scan_account, acc, args.role_name,
                            args.regions, args.max_prefix_len): acc
            for acc in accounts
        }
        for future in as_completed(futures):
            acc = futures[future]
            try:
                all_findings.extend(future.result())
            except Exception as e:
                log.error("Account %s failed: %s", acc["id"], e)

    print_summary(all_findings, args.max_prefix_len)

    csv_path  = f"{args.output_prefix}_{timestamp}.csv"
    xlsx_path = f"{args.output_prefix}_{timestamp}.xlsx"
    write_csv(all_findings, csv_path)
    write_excel(all_findings, xlsx_path)

    log.info("Report complete. CSV: %s  Excel: %s", csv_path, xlsx_path)
    return 1 if any(f.severity == "CRITICAL" for f in all_findings) else 0


if __name__ == "__main__":
    sys.exit(main())
EOF
chmod +x sg_ou_report.py

12. Example Workflow

These wrapper scripts encode the standard sequence. Run the OU report first to see where permissive rules are most concentrated. Then run analyse and plan against the accounts you want to tighten, review the plan output carefully, and apply it. Keep the diagnose script on hand for the 48 hours afterwards. Replace the region, OU ID, role name, and <timestamp> placeholders with your actual values before running anything: the angle brackets are written literally into the generated scripts, and the shell would treat them as redirections rather than as part of a filename. The scripts assume you have run ./install.sh first and that your AWS credentials are available in your environment or profile.

cat > run_ou_report.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate

python sg_ou_report.py \
  --ou-id ou-xxxx-xxxxxxxx \
  --role-name YourAuditRole \
  --regions af-south-1 eu-west-1 \
  --max-prefix-len 24 \
  --workers 8
EOF
chmod +x run_ou_report.sh

cat > run_analyse.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate

python sg_tightener.py analyse \
  --region af-south-1 \
  --days 90
EOF
chmod +x run_analyse.sh

cat > run_plan.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate

python sg_tightener.py plan \
  --region af-south-1 \
  --approved-ips approved_ips_af-south-1_<timestamp>.json \
  --gap-tolerance 0.30
EOF
chmod +x run_plan.sh

cat > run_apply.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate

python sg_tightener.py apply \
  --region af-south-1 \
  --plan sg_plan_af-south-1_<timestamp>.json
EOF
chmod +x run_apply.sh

cat > run_revert.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate

python sg_tightener.py revert \
  --region af-south-1 \
  --backup sg_backup_af-south-1_<timestamp>.json
EOF
chmod +x run_revert.sh

cat > run_diagnose.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate

python sg_diagnose.py \
  --region af-south-1 \
  --hours 4 \
  --approved-ips approved_ips_af-south-1_<timestamp>.json \
  --apply
EOF
chmod +x run_diagnose.sh
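
Filling in the <timestamp> placeholders by hand is the most error-prone part of this workflow. The following is a minimal sketch of a helper that picks the newest matching artefact instead. It assumes the artefacts sit in one directory, follow the <prefix>_<region>_<timestamp>.json naming used in the scripts above, and carry a timestamp that sorts lexicographically in time order (as a %Y%m%d_%H%M%S stamp would). The name latest_artefact is hypothetical and is not part of sg-tightener.

```python
from pathlib import Path
from typing import Optional


def latest_artefact(prefix: str, region: str, directory: str = ".") -> Optional[Path]:
    """Return the newest '<prefix>_<region>_<timestamp>.json' file, or None.

    Assumption: the timestamp portion of the filename sorts
    lexicographically in chronological order.
    """
    pattern = f"{prefix}_{region}_*.json"
    candidates = sorted(Path(directory).glob(pattern), key=lambda p: p.name)
    return candidates[-1] if candidates else None


if __name__ == "__main__":
    # Hypothetical usage: locate the most recent plan for one region
    plan = latest_artefact("sg_plan", "af-south-1")
    print(plan if plan else "no plan file found")
```

Printing the resolved path and confirming it by eye before passing it to apply or revert keeps the human review step that the plan and apply split is designed around.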

13. Test Suite

The CIDR collapsing algorithm has enough edge cases that a regression test suite is essential for future changes. The tests below cover every bug identified during the original code review plus the algorithm’s correctness boundaries. Drop this file in the same directory as sg_tightener.py and run it with python sg_tightener_test.py to confirm the tool behaves correctly before any deployment. Any future fix to the algorithm or parsing logic should keep this suite passing.

cat > sg_tightener_test.py << 'EOF'
#!/usr/bin/env python3
"""
sg-tightener test suite. Validates the CIDR collapsing algorithm, helper
functions, and parsing logic against the bugs identified during code review.

Run from the same directory as sg_tightener.py:

    python sg_tightener_test.py

Exit code 0 if all tests pass, 1 otherwise.
"""

import ipaddress
import json
import logging
import os
import random
import sys
import tempfile
import unittest
from datetime import datetime, timezone

# Silence the tool's own logging during tests
logging.getLogger("sg_tightener").setLevel(logging.CRITICAL)

from sg_tightener import (
    collapse_ips_to_cidrs,
    is_private,
    is_permissive_cidr,
    parse_flow_log_format,
    extract_ips_from_line,
    parse_s3_destination,
    merge_port_ranges,
    count_non_permissive_rules,
    find_permissive_rules,
    build_replacement_rules,
    _parse_date_from_s3_key,
    safe_write_json,
)


def cover_check(ips, cidrs):
    """Assert every private IP is covered by at least one CIDR."""
    nets = [ipaddress.ip_network(c) for c in cidrs]
    for ip in ips:
        if not is_private(ip):
            continue
        addr = ipaddress.ip_address(ip)
        if not any(addr in n for n in nets):
            raise AssertionError(f"{ip} not covered by {cidrs}")


class TestPrivateAddressClassification(unittest.TestCase):
    """is_private and is_permissive_cidr must use strict containment."""

    def test_rfc1918_addresses_are_private(self):
        self.assertTrue(is_private("10.0.0.1"))
        self.assertTrue(is_private("172.16.5.10"))
        self.assertTrue(is_private("192.168.1.1"))

    def test_public_addresses_are_not_private(self):
        self.assertFalse(is_private("8.8.8.8"))
        self.assertFalse(is_private("1.1.1.1"))
        self.assertFalse(is_private("172.32.0.1"))  # outside 172.16/12

    def test_empty_and_garbage(self):
        self.assertFalse(is_private(""))
        self.assertFalse(is_private("not-an-ip"))
        self.assertFalse(is_private(None))

    def test_permissive_cidr_uses_subset_not_overlap(self):
        # 192.0.0.0/4 overlaps 192.168.0.0/16 but is NOT private
        self.assertFalse(is_permissive_cidr("192.0.0.0/4"))
        # 10.0.0.0/16 is fully contained in 10.0.0.0/8 and broad enough
        self.assertTrue(is_permissive_cidr("10.0.0.0/16"))
        self.assertTrue(is_permissive_cidr("172.16.0.0/16"))

    def test_permissive_cidr_excludes_tight_blocks(self):
        self.assertFalse(is_permissive_cidr("10.0.0.0/24"))
        self.assertFalse(is_permissive_cidr("10.0.0.0/32"))

    def test_permissive_cidr_excludes_public_blocks(self):
        self.assertFalse(is_permissive_cidr("0.0.0.0/0"))
        self.assertFalse(is_permissive_cidr("8.0.0.0/8"))
        self.assertFalse(is_permissive_cidr(""))


class TestCidrCollapsing(unittest.TestCase):
    """The CIDR collapsing algorithm must actually collapse IPs."""

    def test_empty_input(self):
        self.assertEqual(collapse_ips_to_cidrs([]), [])

    def test_single_ip(self):
        self.assertEqual(collapse_ips_to_cidrs(["10.0.1.5"]), ["10.0.1.5/32"])

    def test_two_adjacent_ips_collapse_to_slash_31(self):
        r = collapse_ips_to_cidrs(["10.0.1.0", "10.0.1.1"])
        self.assertEqual(r, ["10.0.1.0/31"])

    def test_sixteen_consecutive_ips_collapse_to_slash_28(self):
        """Bug 1 regression: this used to stay as 16 separate /32s."""
        ips = [f"10.0.1.{i}" for i in range(0, 16)]
        r = collapse_ips_to_cidrs(ips)
        self.assertEqual(r, ["10.0.1.0/28"])
        cover_check(ips, r)

    def test_199_ips_in_slash_24_collapse_heavily(self):
        """Bug 1 regression: this used to stay as 199 /32s."""
        ips = [f"10.0.1.{i}" for i in range(1, 200)]
        r = collapse_ips_to_cidrs(ips)
        self.assertLessEqual(len(r), 4, f"Expected <= 4 blocks, got {r}")
        cover_check(ips, r)

    def test_200_random_ips_fit_within_budget(self):
        random.seed(42)
        ips = [
            f"10.{random.randint(0, 5)}.{random.randint(0, 255)}.{random.randint(1, 254)}"
            for _ in range(200)
        ]
        r = collapse_ips_to_cidrs(ips, max_rules=60)
        self.assertLessEqual(len(r), 60)
        cover_check(ips, r)

    def test_sparse_ips_stay_as_slash_32(self):
        """Non-clustered IPs across different /24s should remain /32."""
        ips = ["10.0.1.1", "10.0.5.50", "10.0.100.200"]
        r = collapse_ips_to_cidrs(ips)
        self.assertEqual(set(r), {"10.0.1.1/32", "10.0.5.50/32", "10.0.100.200/32"})

    def test_public_ips_are_dropped(self):
        ips = ["10.0.1.5", "8.8.8.8", "192.168.1.1"]
        r = collapse_ips_to_cidrs(ips)
        nets = [ipaddress.ip_network(c) for c in r]
        # 8.8.8.8 must not be in any output network
        self.assertFalse(any(ipaddress.ip_address("8.8.8.8") in n for n in nets))
        # 10.0.1.5 and 192.168.1.1 must be covered
        self.assertTrue(any(ipaddress.ip_address("10.0.1.5") in n for n in nets))
        self.assertTrue(any(ipaddress.ip_address("192.168.1.1") in n for n in nets))

    def test_force_fit_guarantees_budget(self):
        """When tolerance cannot reduce enough, force fit must hit budget."""
        ips = [f"10.0.{i}.{j}" for i in range(0, 20) for j in [1, 50, 100, 200]]
        r = collapse_ips_to_cidrs(ips, max_rules=5)
        self.assertLessEqual(len(r), 5)
        cover_check(ips, r)

    def test_default_tolerance_rejects_high_gap(self):
        """At 30% tolerance, an 11-IP /28 (gap 31%) should NOT collapse to /28."""
        ips = [f"10.0.1.{i}" for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
        r = collapse_ips_to_cidrs(ips)
        self.assertNotIn("10.0.1.0/28", r)
        cover_check(ips, r)

    def test_high_density_collapses_to_wider_block(self):
        """At 30% tolerance, 199 IPs in a /24 (gap 22%) should collapse to wide blocks."""
        ips = [f"10.0.1.{i}" for i in range(1, 200)]
        r = collapse_ips_to_cidrs(ips)
        # The result could be a single /24 or a /25 plus smaller blocks; either is
        # acceptable so long as the rule count is small and every IP stays covered.
        self.assertLessEqual(len(r), 4)
        cover_check(ips, r)


class TestFlowLogFormatParsing(unittest.TestCase):
    """Custom log formats must be parsed and used correctly."""

    def test_default_v2_format(self):
        fields = parse_flow_log_format("")
        self.assertEqual(fields[3], "srcaddr")
        self.assertEqual(fields[12], "action")

    def test_default_v2_format_from_none(self):
        fields = parse_flow_log_format(None)
        self.assertEqual(fields[3], "srcaddr")

    def test_custom_format_parsing(self):
        custom = "${srcaddr} ${dstaddr} ${action} ${instance-id}"
        fields = parse_flow_log_format(custom)
        self.assertEqual(fields, ["srcaddr", "dstaddr", "action", "instance-id"])

    def test_extract_ip_from_default_format(self):
        positions = {n: i for i, n in enumerate(parse_flow_log_format(""))}
        line = "2 123456789 eni-abc 10.0.1.5 10.0.1.10 12345 443 6 1 100 1000000 1000060 ACCEPT OK"
        ip = extract_ips_from_line(line, positions, "ACCEPT")
        self.assertEqual(ip, "10.0.1.5")

    def test_extract_ip_from_custom_format(self):
        positions = {"srcaddr": 0, "dstaddr": 1, "action": 2, "instance-id": 3}
        ip = extract_ips_from_line("10.0.1.5 10.0.1.10 ACCEPT i-12345", positions, "ACCEPT")
        self.assertEqual(ip, "10.0.1.5")

    def test_action_filter_excludes_other_actions(self):
        positions = {"srcaddr": 0, "dstaddr": 1, "action": 2, "instance-id": 3}
        ip = extract_ips_from_line("10.0.1.5 10.0.1.10 ACCEPT i-12345", positions, "REJECT")
        self.assertIsNone(ip)

    def test_public_source_ip_not_extracted(self):
        positions = {"srcaddr": 0, "dstaddr": 1, "action": 2, "instance-id": 3}
        ip = extract_ips_from_line("8.8.8.8 10.0.1.10 ACCEPT i-12345", positions, "ACCEPT")
        self.assertIsNone(ip)

    def test_dash_source_ip_not_extracted(self):
        positions = {"srcaddr": 0, "dstaddr": 1, "action": 2}
        ip = extract_ips_from_line("- 10.0.1.10 ACCEPT", positions, "ACCEPT")
        self.assertIsNone(ip)

    def test_short_line_handled_safely(self):
        positions = {"srcaddr": 0, "dstaddr": 1, "action": 2, "instance-id": 3}
        ip = extract_ips_from_line("incomplete", positions, "ACCEPT")
        self.assertIsNone(ip)


class TestS3DestinationParsing(unittest.TestCase):
    """All AWS partitions must parse correctly."""

    def test_aws_partition(self):
        b, p = parse_s3_destination("arn:aws:s3:::my-bucket/flow-logs/")
        self.assertEqual((b, p), ("my-bucket", "flow-logs/"))

    def test_aws_cn_partition(self):
        b, p = parse_s3_destination("arn:aws-cn:s3:::cn-bucket/path")
        self.assertEqual((b, p), ("cn-bucket", "path"))

    def test_aws_us_gov_partition(self):
        b, p = parse_s3_destination("arn:aws-us-gov:s3:::gov-bucket/")
        self.assertEqual((b, p), ("gov-bucket", ""))

    def test_s3_url_form(self):
        b, p = parse_s3_destination("s3://path-bucket/some/prefix")
        self.assertEqual((b, p), ("path-bucket", "some/prefix"))

    def test_bucket_with_no_prefix(self):
        b, p = parse_s3_destination("arn:aws:s3:::just-a-bucket")
        self.assertEqual((b, p), ("just-a-bucket", ""))


class TestDateFromS3Key(unittest.TestCase):
    """S3 key date parsing must beat LastModified for accuracy."""

    def test_aws_standard_path_format(self):
        key = "AWSLogs/123456789012/vpcflowlogs/af-south-1/2026/01/15/log.gz"
        result = _parse_date_from_s3_key(key)
        self.assertEqual(result, datetime(2026, 1, 15, tzinfo=timezone.utc))

    def test_no_date_in_path(self):
        self.assertIsNone(_parse_date_from_s3_key("random/key/without/date.gz"))

    def test_invalid_date_returns_none(self):
        # 2026/13/45 is invalid
        self.assertIsNone(_parse_date_from_s3_key("logs/2026/13/45/file.gz"))


class TestPortRangeMerging(unittest.TestCase):
    """Adjacent port ranges must merge to reduce rule count."""

    def test_adjacent_ports_merge(self):
        rules = [
            {"cidr": "10.0.0.0/24", "protocol": "tcp", "from_port": 80, "to_port": 80},
            {"cidr": "10.0.0.0/24", "protocol": "tcp", "from_port": 81, "to_port": 90},
            {"cidr": "10.0.0.0/24", "protocol": "tcp", "from_port": 443, "to_port": 443},
        ]
        merged = merge_port_ranges(rules, target_count=1)
        # 80-80 and 81-90 are adjacent so merge to 80-90; 443 stays separate
        self.assertEqual(len(merged), 2)
        ranges = sorted([(r["from_port"], r["to_port"]) for r in merged])
        self.assertEqual(ranges, [(80, 90), (443, 443)])

    def test_protocol_all_collapses_to_full_range(self):
        rules = [
            {"cidr": "10.0.0.0/24", "protocol": "-1", "from_port": 0, "to_port": 65535},
            {"cidr": "10.0.0.0/24", "protocol": "-1", "from_port": 80, "to_port": 80},
        ]
        merged = merge_port_ranges(rules, target_count=1)
        self.assertEqual(len(merged), 1)
        self.assertEqual(merged[0]["from_port"], 0)
        self.assertEqual(merged[0]["to_port"], 65535)

    def test_no_merging_needed_when_under_target(self):
        rules = [
            {"cidr": "10.0.0.0/24", "protocol": "tcp", "from_port": 80, "to_port": 80},
        ]
        merged = merge_port_ranges(rules, target_count=10)
        self.assertEqual(merged, rules)


class TestSecurityGroupRuleAnalysis(unittest.TestCase):
    """Rule classification must correctly identify permissive vs non-permissive."""

    def test_find_permissive_rules_flags_broad_private(self):
        sg = {
            "GroupId": "sg-123",
            "IpPermissions": [
                {
                    "IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
                    "IpRanges": [{"CidrIp": "10.0.0.0/16"}],
                }
            ],
        }
        result = find_permissive_rules(sg)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["cidr"], "10.0.0.0/16")

    def test_find_permissive_rules_ignores_tight_blocks(self):
        sg = {
            "GroupId": "sg-123",
            "IpPermissions": [
                {
                    "IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
                    "IpRanges": [{"CidrIp": "10.0.0.0/24"}],
                }
            ],
        }
        self.assertEqual(find_permissive_rules(sg), [])

    def test_find_permissive_rules_ignores_public(self):
        sg = {
            "GroupId": "sg-123",
            "IpPermissions": [
                {
                    "IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
                    "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
                }
            ],
        }
        self.assertEqual(find_permissive_rules(sg), [])

    def test_count_non_permissive_includes_sg_references(self):
        sg = {
            "GroupId": "sg-123",
            "IpPermissions": [
                {
                    "IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
                    "IpRanges": [{"CidrIp": "10.0.0.0/16"}],  # permissive
                    "UserIdGroupPairs": [{"GroupId": "sg-other"}],  # kept
                },
                {
                    "IpProtocol": "tcp", "FromPort": 22, "ToPort": 22,
                    "IpRanges": [{"CidrIp": "10.0.0.0/24"}],  # kept (tight)
                },
            ],
        }
        # 1 SG reference + 1 tight CIDR = 2 non-permissive
        self.assertEqual(count_non_permissive_rules(sg), 2)


class TestBuildReplacementRules(unittest.TestCase):
    """Replacement rules must cross-product cidrs with ports and dedupe."""

    def test_single_permissive_with_two_cidrs(self):
        permissive = [
            {
                "permission": {
                    "IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
                    "IpRanges": [{"CidrIp": "10.0.0.0/16"}],
                },
                "cidr": "10.0.0.0/16",
            }
        ]
        rules, port_merged = build_replacement_rules(
            permissive, ["10.0.10.0/24", "10.0.20.0/29"], target_count=60
        )
        self.assertEqual(len(rules), 2)
        self.assertFalse(port_merged)
        cidrs = {r["cidr"] for r in rules}
        self.assertEqual(cidrs, {"10.0.10.0/24", "10.0.20.0/29"})

    def test_empty_collapsed_cidrs_returns_empty(self):
        permissive = [
            {
                "permission": {
                    "IpProtocol": "tcp", "FromPort": 443, "ToPort": 443,
                    "IpRanges": [{"CidrIp": "10.0.0.0/16"}],
                },
                "cidr": "10.0.0.0/16",
            }
        ]
        rules, port_merged = build_replacement_rules(permissive, [], target_count=60)
        self.assertEqual(rules, [])
        self.assertFalse(port_merged)


class TestSafeWriteJson(unittest.TestCase):
    """Atomic JSON writes must not leave partial files on failure."""

    def test_atomic_write_creates_correct_file(self):
        with tempfile.TemporaryDirectory() as td:
            path = os.path.join(td, "out.json")
            safe_write_json({"foo": "bar", "n": 42}, path)
            with open(path) as f:
                data = json.load(f)
            self.assertEqual(data, {"foo": "bar", "n": 42})
            # No leftover temp file, whatever name the implementation gives it
            self.assertEqual(os.listdir(td), ["out.json"])


if __name__ == "__main__":
    # Use verbose output and a non-zero exit on failure
    runner = unittest.TextTestRunner(verbosity=2)
    suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
    result = runner.run(suite)
    sys.exit(0 if result.wasSuccessful() else 1)
EOF
chmod +x sg_tightener_test.py

The suite covers all 20 bugs from the original review plus the algorithm’s correctness boundaries. Run it like this:

cat > run_tests.sh << 'EOF'
#!/usr/bin/env bash
set -euo pipefail
source .venv/bin/activate
python sg_tightener_test.py
EOF
chmod +x run_tests.sh
./run_tests.sh

The expected output ends with something like "Ran 44 tests in 0.9s" followed by "OK". The suite above defines 44 test methods, so a lower count means some tests were not collected. Any failure indicates a regression that must be fixed before deploying the tool. The CIDR collapsing tests in particular guard against the original bug in which the algorithm produced no collapsing at all, which would silently make the tool ineffective rather than visibly broken.
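
The gap percentages quoted in the test docstrings come from simple arithmetic: the share of a candidate block's addresses that never appeared in the flow logs. The following is a minimal sketch of that calculation using only the standard ipaddress module. gap_fraction is a hypothetical helper for illustration; it is not claimed to be how collapse_ips_to_cidrs computes the value internally.

```python
import ipaddress


def gap_fraction(observed_ips, cidr):
    """Fraction of addresses in `cidr` that are absent from `observed_ips`."""
    net = ipaddress.ip_network(cidr)
    seen = {ipaddress.ip_address(ip) for ip in observed_ips}
    covered = sum(1 for ip in seen if ip in net)
    return 1 - covered / net.num_addresses


# 11 of the 16 addresses in a /28 observed: 5/16 unobserved, above a 0.30 tolerance
eleven = [f"10.0.1.{i}" for i in range(11)]
print(f"{gap_fraction(eleven, '10.0.1.0/28'):.4f}")   # 0.3125

# 199 of the 256 addresses in a /24 observed: 57/256 unobserved, below 0.30
dense = [f"10.0.1.{i}" for i in range(1, 200)]
print(f"{gap_fraction(dense, '10.0.1.0/24'):.4f}")    # 0.2227
```

Working the numbers this way makes it easier to pick boundary cases when adding tests: a block whose gap sits just above the tolerance should be rejected, and one just below should be accepted.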

14. Relationship to CloudToRepo

sg-tightener lives under the CloudToRepo project at cloudtorepo.com as a security extension. CloudToRepo’s core purpose is to reverse-engineer existing AWS infrastructure into Terraform so you can understand and version-control what you have. sg-tightener extends that philosophy in the security direction: rather than accepting that your security groups are an undocumented product of historical decisions, it gives you an evidence-based, auditable, repeatable way to understand and tighten them. Both tools operate on the same principle that infrastructure you cannot inspect and reason about is infrastructure you cannot trust.

15. What This Does Not Do

sg-tightener does not audit public internet exposure. Rules with a source of 0.0.0.0/0 are left exactly as they are. It does not evaluate whether specific services should be reachable at all, only whether the source CIDR for existing private rules is unnecessarily broad. It does not manage egress rules. IPv6 rules are out of scope in this release. Network ACL tightening is reported but not automated, for the reasons discussed in section 6.

What it does is convert one specific, pervasive class of security debt into a defensible, evidence-based configuration without requiring weeks of manual forensic work. For most enterprise AWS accounts that have grown organically over several years, replacing a /16, which admits 65,536 source addresses, with a handful of small blocks can reduce the lateral movement surface area by orders of magnitude.
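
To put a number on that reduction, compare how many source addresses a rule admits before and after tightening. The sketch below counts addresses with the standard ipaddress module. The replacement blocks are the illustrative pair used in the test suite, not output from a real account, and address_count is a hypothetical helper.

```python
import ipaddress


def address_count(cidrs):
    """Total distinct addresses admitted by a list of CIDR blocks."""
    # collapse_addresses merges overlapping or adjacent blocks,
    # so no address is counted twice
    nets = ipaddress.collapse_addresses(ipaddress.ip_network(c) for c in cidrs)
    return sum(n.num_addresses for n in nets)


before = address_count(["10.0.0.0/16"])                   # inherited datacenter rule
after = address_count(["10.0.10.0/24", "10.0.20.0/29"])   # evidence-based replacements

print(before, after, round(before / after))   # 65536 264 248
```

In this illustrative case the permitted source space shrinks by a factor of roughly 250. The real figure depends entirely on how clustered the observed traffic is.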

16. The Broader Point

Most organisations spend considerable effort building security controls at the perimeter: WAFs, DDoS protection, certificate management, identity federation. Those controls matter. What receives far less attention is the internal trust model once traffic is past the perimeter. The implicit assumption in most hybrid cloud estates is that the corporate network is trusted, and that assumption is encoded directly into security group rules as broad RFC 1918 CIDR blocks that nobody has revisited since they were written.

That assumption was questionable when it was made and it is indefensible now. Modern threat models assume that the corporate network is already compromised, or will be. Ransomware operators routinely move laterally across flat trusted networks before triggering their payload. Compromised build agents are a standard initial access vector precisely because they tend to sit in trusted network ranges with broad permissions into production environments. The path from a compromised developer laptop to a production database should not exist, but in most hybrid cloud estates it does, encoded quietly in a security group rule that says 10.0.0.0/16.

The cloud did not eliminate flat networks. It gave many organisations the tools to build more sophisticated ones while quietly replicating the same trust assumptions they had always made. sg-tightener exists because trust should be earned through observed behaviour, not inherited from a datacenter subnet designed fifteen years ago.

sg-tightener is an open source extension of CloudToRepo. Contributions welcome.

Andrew Baker · andrewbaker.ninja · Group CIO, Capitec Bank