# SPDX-FileCopyrightText: 2025 German Aerospace Center (DLR)
#
# SPDX-License-Identifier: BSD-3-Clause
import abc
import typing
import networkx as nx
from flowstrider import settings
from flowstrider.models import common_models, dataflowdiagram, threat
from flowstrider.rules import attributes_dict
# Helper method to unify strings for matching
def string_unify(input: str) -> str:
string = input.casefold()
string = string.replace(" ", "")
string = string.replace("-", "")
string = string.replace("_", "")
return string
# Helper method
def meet_any_requirement(attribute_in, requirements: list) -> bool:
if requirements is None or len(requirements) == 0:
return True
if attribute_in is None:
return False
# Convert
if isinstance(attribute_in, str):
attribute = string_unify(attribute_in)
if attribute == "true":
attribute = True
elif attribute == "false":
attribute = False
else:
attribute = attribute_in
# Check boolean
if isinstance(attribute, bool) and isinstance(requirements[0], bool):
return any(attribute == req for req in requirements)
elif isinstance(attribute, bool) or isinstance(requirements[0], bool):
return False
# Check string
for req in requirements:
req = string_unify(req)
if req in attribute:
return True
return False
# Helper method
def meet_all_requirement(attributes: list, requirements: list) -> bool:
if not isinstance(attributes, list) or len(attributes) < len(requirements):
return False
for req in requirements:
req = string_unify(req)
req_fulfilled = False
for att in attributes:
att = string_unify(att)
if req in att:
req_fulfilled = True
if not req_fulfilled:
return False
return True
# Helper method
def attributes_to_string(current_attributes):
if current_attributes is None:
return ""
if isinstance(current_attributes, bool) or isinstance(current_attributes, int):
return str(current_attributes)
if isinstance(current_attributes, list):
return ", ".join(current_attributes)
return current_attributes
Entity: typing.TypeAlias = typing.Union[
common_models.Node,
common_models.Edge,
]
class Rule:
display_name: typing.Optional[str] = None
BASE_SEVERITY: float = 0.0 # 0 is lowest severity
severity: float # Set per evaluation and is used for generating threats
short_description: str
long_description: str = "No long description given."
attribute_names: typing.List[str] = []
mitigation_options: typing.List[str] = []
requirement: str = ""
req_status: str = "Status missing"
@classmethod
def _generate_threat(
cls, location: threat.Location, dfd: dataflowdiagram.DataflowDiagram
) -> typing.List[threat.Threat]:
"""Generate a threat from this rule taking into consideration the location"""
def get_node_multiplicator(node: common_models.Node) -> float:
# Get severity of Node taking into consideration all clusters the node is in
severity_multp = node.severity_multiplier
clusters = dfd.get_clusters_for_node_id(node.id)
for c in clusters:
severity_multp *= c.severity_multiplier
return severity_multp
# Calculate severity of the threat based on rule severity and multiplicators at
# ...the location
new_severity = cls.severity
# For Nodes
if isinstance(location, common_models.Node):
new_severity *= get_node_multiplicator(location)
# For Edges
if isinstance(location, common_models.Edge):
# Take the severity of sink or source, whichever is higher
source_sev = get_node_multiplicator(dfd.get_node_by_id(location.source_id))
sink_sev = get_node_multiplicator(dfd.get_node_by_id(location.sink_id))
new_severity *= max(source_sev, sink_sev)
# Severity can't be negative
new_severity = max(0.0, new_severity)
return [
threat.Threat(
source=(
cls.display_name if cls.display_name is not None else cls.__name__
),
source_internal=cls.__name__,
location=location,
severity=new_severity,
short_description=cls.short_description,
long_description=cls.long_description,
mitigation_options=cls.mitigation_options,
requirement=cls.requirement,
req_status=cls.req_status,
)
]
@classmethod
def set_status(cls, entity: Entity):
global _
_ = settings.lang_out.gettext
status_list = []
for attribute_name in cls.attribute_names:
if attribute_name not in entity.attributes:
status_list.append(
_("Attribute missing: {name}").format(
name=attributes_dict.attributes[attribute_name][0]
)
)
else:
status_list.append(
("{name} = {curr}").format(
name=attributes_dict.attributes[attribute_name][0],
curr=attributes_to_string(entity.attributes[attribute_name]),
)
)
cls.req_status = "\n".join(status_list)
@classmethod
@abc.abstractmethod
def init_texts(cls):
raise NotImplementedError
[docs]
class NodeRule(Rule):
@classmethod
@abc.abstractmethod
def _test(
cls, node: common_models.Node, dfd: dataflowdiagram.DataflowDiagram
) -> bool:
raise NotImplementedError
[docs]
@classmethod
def evaluate(
cls, node: common_models.Node, dfd: dataflowdiagram.DataflowDiagram
) -> typing.List[threat.Threat]:
if cls._test(node, dfd):
cls.set_status(node)
return cls._generate_threat(node, dfd)
else:
return []
[docs]
class NodeTagRule(NodeRule):
node_tags_all = {}
node_tags_any = {}
@classmethod
def _test(
cls, node: common_models.Node, dfd: dataflowdiagram.DataflowDiagram
) -> bool:
for tag in cls.node_tags_all:
if tag not in node.tags:
return False
else:
for tag in cls.node_tags_any:
if tag in node.tags:
return True
return False
[docs]
class EdgeRule(Rule):
@classmethod
@abc.abstractmethod
def _test(
cls, edge: common_models.Edge, dfd: dataflowdiagram.DataflowDiagram
) -> bool:
raise NotImplementedError
[docs]
@classmethod
def evaluate(
cls, edge: common_models.Edge, dfd: dataflowdiagram.DataflowDiagram
) -> typing.List[threat.Threat]:
if cls._test(edge, dfd):
cls.set_status(edge)
return cls._generate_threat(edge, dfd)
else:
return []
[docs]
class EdgeTagRule(EdgeRule):
edge_tags_all = {}
edge_tags_any = {}
@classmethod
def _test(
cls, edge: common_models.Edge, dfd: dataflowdiagram.DataflowDiagram
) -> bool:
for tag in cls.edge_tags_all:
if tag not in edge.tags:
return False
else:
for tag in cls.edge_tags_any:
if tag in edge.tags:
return True
return False
[docs]
class DataflowDiagramRule(Rule):
@classmethod
@abc.abstractmethod
def _test(cls, dfd: dataflowdiagram.DataflowDiagram) -> bool:
raise NotImplementedError
[docs]
@classmethod
def evaluate(
cls, dfd: dataflowdiagram.DataflowDiagram
) -> typing.List[threat.Threat]:
if cls._test(dfd):
cls.set_status(dfd)
return cls._generate_threat(dfd, dfd)
else:
return []
[docs]
class GraphRule(Rule):
@classmethod
@abc.abstractmethod
def _test(cls, graph: nx.DiGraph, dfd: dataflowdiagram.DataflowDiagram) -> bool:
raise NotImplementedError
[docs]
@classmethod
def evaluate(
cls, graph: nx.DiGraph, dfd: dataflowdiagram.DataflowDiagram
) -> typing.List[threat.Threat]:
if cls._test(graph, dfd):
raise NotImplementedError # Status has to be set
cls.set_status()
return cls._generate_threat(dfd)
else:
return []
RuleType = typing.TypeVar("RuleType", bound=Rule)
NodeRuleType = typing.TypeVar("NodeRuleType", bound=NodeRule)
EdgeRuleType = typing.TypeVar("EdgeRuleType", bound=EdgeRule)
DataflowDiagramRuleType = typing.TypeVar(
"DataflowDiagramRuleType", bound=DataflowDiagramRule
)
GraphRuleType = typing.TypeVar("GraphRuleType", bound=GraphRule)
[docs]
class DataflowDiagramRuleCollection(abc.ABC):
tags: typing.Set[str] = {}
name: str = "Rule collection"
references: typing.Set[str] = []
node_rules: typing.List[NodeRuleType] = []
edge_rules: typing.List[EdgeRuleType] = []
dfd_rules: typing.List[DataflowDiagramRuleType] = []
graph_rules: typing.List[GraphRule] = []
[docs]
@classmethod
def evaluate(
cls, dfd: dataflowdiagram.DataflowDiagram
) -> typing.List[threat.Threat]:
results: typing.List[threat.Threat] = []
for tag in cls.tags:
if tag not in dfd.tags:
return results
# Evaluate node rules
for node_rule in cls.node_rules:
for _, node in dfd.nodes.items():
rule_results = node_rule.evaluate(node, dfd)
results += rule_results
# Evaluate edge rules
for edge_rule in cls.edge_rules:
for _, edge in dfd.edges.items():
rule_results = edge_rule.evaluate(edge, dfd)
results += rule_results
# Evaluate dfd rules
for dfd_rule in cls.dfd_rules:
rule_results = dfd_rule.evaluate(dfd)
results += rule_results
# Evaluate graph rules
if len(cls.graph_rules) > 0:
graph = nx.DiGraph()
for node_id, node in dfd.nodes.items():
graph.add_node(node_id, node_data=node)
for _, edge in dfd.edges.items():
graph.add_edge(edge.source_id, edge.sink_id, edge_data=edge)
for graph_rule in cls.graph_rules:
rule_results = graph_rule.evaluate(graph, dfd)
results += rule_results
return results
__all__ = [
"NodeRule",
"NodeTagRule",
"EdgeRule",
"EdgeTagRule",
"DataflowDiagramRule",
"GraphRule",
"DataflowDiagramRuleCollection",
]