Source code for pymc_marketing.mmm.validating
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Validating methods for MMM classes."""
from collections.abc import Callable
import pandas as pd
# Names exported by ``from <module> import *``: the two tagging decorators
# first, then the validation classes.
__all__ = [
    "validation_method_X",
    "validation_method_y",
    "ValidateControlColumns",
    "ValidateTargetColumn",
    "ValidateDateColumn",
    "ValidateChannelColumns",
]
[docs]
def validation_method_y(method: Callable) -> Callable:
    """Tag ``method`` as a validation method for ``y``.

    The callable is modified in place: a ``_tags`` dictionary is attached
    when the callable does not already have one, and its ``"validation_y"``
    entry is set to ``True``.

    Parameters
    ----------
    method : Callable
        The callable to tag.

    Returns
    -------
    Callable
        The very same ``method`` object, now carrying the tag.
    """
    try:
        tags = method._tags  # type: ignore[attr-defined]
    except AttributeError:
        # First tag on this callable: create the registry it will carry.
        tags = {}
        method._tags = tags  # type: ignore[attr-defined]
    tags["validation_y"] = True
    return method
[docs]
def validation_method_X(method: Callable) -> Callable:
    """Tag ``method`` as a validation method for ``X``.

    The callable is modified in place: a ``_tags`` dictionary is attached
    when the callable does not already have one, and its ``"validation_X"``
    entry is set to ``True``.

    Parameters
    ----------
    method : Callable
        The callable to tag.

    Returns
    -------
    Callable
        The very same ``method`` object, now carrying the tag.
    """
    try:
        tags = method._tags  # type: ignore[attr-defined]
    except AttributeError:
        # First tag on this callable: create the registry it will carry.
        tags = {}
        method._tags = tags  # type: ignore[attr-defined]
    tags["validation_X"] = True
    return method
[docs]
class ValidateTargetColumn:
    """Provide the validation check applied to the target ``y``."""

    @validation_method_y
    def validate_target(self, data: pd.Series) -> None:
        """Ensure the target holds at least one element.

        Parameters
        ----------
        data : pd.Series
            Target values to check.

        Raises
        ------
        ValueError
            If ``data`` has length zero.
        """
        n_elements = len(data)
        if n_elements < 1:
            raise ValueError("y must have at least one element")
[docs]
class ValidateDateColumn:
    """Provide the validation check for the date column of ``X``.

    ``date_column`` is only declared here, never assigned; it has to be set
    on the instance before ``validate_date_col`` is called.
    """

    # Name of the column in ``X`` that holds the dates.
    date_column: str

    @validation_method_X
    def validate_date_col(self, data: pd.DataFrame) -> None:
        """Check that the date column exists and has no repeated values.

        Parameters
        ----------
        data : pd.DataFrame
            Frame expected to contain the column named by ``date_column``.

        Raises
        ------
        ValueError
            If the column is missing from ``data`` or its values are not
            unique.
        """
        column_present = self.date_column in data.columns
        if not column_present:
            raise ValueError(f"date_col {self.date_column} not in data")

        has_repeats = not data[self.date_column].is_unique
        if has_repeats:
            raise ValueError(f"date_col {self.date_column} has repeated values")
[docs]
class ValidateChannelColumns:
    """Provide the validation check for the channel columns of ``X``.

    ``channel_columns`` is only declared here, never assigned; it has to be
    set on the instance before ``validate_channel_columns`` is called.
    """

    # Names of the channel columns. ``tuple[str, ...]`` (any length) rather
    # than ``tuple[str]`` (exactly one element): the validator below accepts
    # tuples of every non-zero length.
    channel_columns: list[str] | tuple[str, ...]

    @validation_method_X
    def validate_channel_columns(self, data: pd.DataFrame) -> None:
        """Check the channel column names and the values they select.

        The checks run in this order, and the first failure raises:
        container type, non-emptiness, presence in ``data``, absence of
        duplicate names, absence of negative values.

        Parameters
        ----------
        data : pd.DataFrame
            Frame expected to contain every column in ``channel_columns``.

        Raises
        ------
        ValueError
            If ``channel_columns`` is not a list or tuple, is empty, names a
            column missing from ``data``, contains duplicate names, or
            selects any negative value in ``data``.
        """
        if not isinstance(self.channel_columns, list | tuple):
            raise ValueError("channel_columns must be a list or tuple")
        if len(self.channel_columns) == 0:
            raise ValueError("channel_columns must not be empty")

        unique_names = set(self.channel_columns)
        if not unique_names.issubset(data.columns):
            raise ValueError(f"channel_columns {self.channel_columns} not in data")
        # Fewer unique names than entries means at least one name repeats.
        if len(unique_names) != len(self.channel_columns):
            raise ValueError(
                f"channel_columns {self.channel_columns} contains duplicates"
            )

        # First ``any`` reduces per column, second across columns.
        channel_data = data.filter(items=list(self.channel_columns))
        if (channel_data < 0).any().any():
            raise ValueError(
                f"channel_columns {self.channel_columns} contains negative values"
            )
[docs]
class ValidateControlColumns:
    """Provide the validation check for the optional control columns of ``X``.

    ``control_columns`` is only declared here, never assigned; it has to be
    set on the instance before ``validate_control_columns`` is called.
    """

    # Names of the control columns, or ``None`` when there are none. Tuples
    # are included in the annotation because the validator below accepts
    # them alongside lists.
    control_columns: list[str] | tuple[str, ...] | None

    @validation_method_X
    def validate_control_columns(self, data: pd.DataFrame) -> None:
        """Check the control column names against ``data``.

        Nothing is checked when ``control_columns`` is ``None``. Otherwise
        the checks run in this order, and the first failure raises:
        container type, non-emptiness, presence in ``data``, absence of
        duplicate names.

        Parameters
        ----------
        data : pd.DataFrame
            Frame expected to contain every column in ``control_columns``.

        Raises
        ------
        ValueError
            If ``control_columns`` is neither ``None`` nor a list or tuple,
            is empty, names a column missing from ``data``, or contains
            duplicate names.
        """
        if self.control_columns is None:
            return
        if not isinstance(self.control_columns, list | tuple):
            raise ValueError("control_columns must be None, a list or tuple")
        if len(self.control_columns) == 0:
            raise ValueError(
                "If control_columns is not None, then it must not be empty"
            )

        unique_names = set(self.control_columns)
        if not unique_names.issubset(data.columns):
            raise ValueError(f"control_columns {self.control_columns} not in data")
        # Fewer unique names than entries means at least one name repeats.
        if len(unique_names) != len(self.control_columns):
            raise ValueError(
                f"control_columns {self.control_columns} contains duplicates"
            )