|
16 | 16 |
|
17 | 17 | import unittest
|
18 | 18 | import os
|
19 |
| -from typing import List, Set, Dict, Optional, Any, Union, Tuple |
| 19 | +from typing import List, Set, Dict, Optional, Any, Callable, Union, Tuple |
20 | 20 |
|
21 | 21 | from pyspark.sql import DataFrame, Row
|
22 | 22 | from pyspark.sql.column import _to_seq # type: ignore[attr-defined]
|
@@ -827,6 +827,29 @@ def test_restore_to_timestamp(self) -> None:
|
827 | 827 | restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()
|
828 | 828 | self.__checkAnswer(restored, [Row(key='a', value=1), Row(key='b', value=2)])
|
829 | 829 |
|
| 830 | + # we cannot test the actual working of restore to timestamp here but we can make sure |
| 831 | + # that the api is being called at least |
| 832 | + def runRestore() -> None: |
| 833 | + DeltaTable.forPath(self.spark, self.tempFile).restoreToTimestamp('05/04/1999') |
| 834 | + self.__intercept(runRestore, "The provided timestamp ('05/04/1999') " |
| 835 | + "cannot be converted to a valid timestamp") |
| 836 | + |
| 837 | + def test_restore_invalid_inputs(self) -> None: |
| 838 | + df = self.spark.createDataFrame([('a', 1), ('b', 2), ('c', 3)], ["key", "value"]) |
| 839 | + df.write.format("delta").save(self.tempFile) |
| 840 | + |
| 841 | + dt = DeltaTable.forPath(self.spark, self.tempFile) |
| 842 | + |
| 843 | + def runRestoreToTimestamp() -> None: |
| 844 | + dt.restoreToTimestamp(12342323232) # type: ignore[arg-type] |
| 845 | + self.__intercept(runRestoreToTimestamp, |
| 846 | + "timestamp needs to be a string but got '<class 'int'>'") |
| 847 | + |
| 848 | + def runRestoreToVersion() -> None: |
| 849 | + dt.restoreToVersion("0") # type: ignore[arg-type] |
| 850 | + self.__intercept(runRestoreToVersion, |
| 851 | + "version needs to be an int but got '<class 'str'>'") |
| 852 | + |
830 | 853 | def __checkAnswer(self, df: DataFrame,
|
831 | 854 | expectedAnswer: List[Any],
|
832 | 855 | schema: Union[StructType, List[str]] = ["key", "value"]) -> None:
|
@@ -870,6 +893,15 @@ def __createFile(self, fileName: str, content: Any) -> None:
|
870 | 893 | def __checkFileExists(self, fileName: str) -> bool:
|
871 | 894 | return os.path.exists(os.path.join(self.tempFile, fileName))
|
872 | 895 |
|
| 896 | + def __intercept(self, func: Callable[[], None], exceptionMsg: str) -> None: |
| 897 | + seenTheRightException = False |
| 898 | + try: |
| 899 | + func() |
| 900 | + except Exception as e: |
| 901 | + if exceptionMsg in str(e): |
| 902 | + seenTheRightException = True |
| 903 | + assert seenTheRightException, ("Did not catch expected Exception:" + exceptionMsg) |
| 904 | + |
873 | 905 |
|
874 | 906 | if __name__ == "__main__":
|
875 | 907 | try:
|
|
0 commit comments