Skip to content

Commit 01ce4f5

Browse files
vkorukanti authored and scottsand-db committed
Minor improvements to Restore Python API implementation and tests
* Add args check
* Add more tests (negative tests)

Existing UTs

Closes #968

GitOrigin-RevId: dee05a116a1a685425a09b34ebb643d86e70486b
1 parent 0a063fa commit 01ce4f5

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

python/delta/tables.py

+12
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ def restoreToVersion(self, version: int) -> DataFrame:
553553
:rtype: pyspark.sql.DataFrame
554554
"""
555555

556+
DeltaTable._verify_type_int(version, "version")
556557
return DataFrame(
557558
self._jdt.restoreToVersion(version),
558559
getattr(self._spark, "_wrapped", self._spark) # type: ignore[attr-defined]
@@ -574,11 +575,22 @@ def restoreToTimestamp(self, timestamp: str) -> DataFrame:
574575
:rtype: pyspark.sql.DataFrame
575576
"""
576577

578+
DeltaTable._verify_type_str(timestamp, "timestamp")
577579
return DataFrame(
578580
self._jdt.restoreToTimestamp(timestamp),
579581
getattr(self._spark, "_wrapped", self._spark) # type: ignore[attr-defined]
580582
)
581583

584+
@staticmethod  # type: ignore[arg-type]
def _verify_type_str(variable: str, name: str) -> None:
    """
    Check that an argument is a Python string.

    :param variable: the value to validate
    :param name: the argument name, used only to build the error message
    :raises ValueError: if ``variable`` is not an instance of ``str``
    """
    # None is not an instance of str, so no separate None check is required.
    if not isinstance(variable, str):
        raise ValueError("%s needs to be a string but got '%s'." % (name, type(variable)))
588+
589+
@staticmethod  # type: ignore[arg-type]
def _verify_type_int(variable: int, name: str) -> None:
    """
    Check that an argument is a Python integer.

    :param variable: the value to validate
    :param name: the argument name, used only to build the error message
    :raises ValueError: if ``variable`` is not an ``int``, or is a ``bool``
    """
    # bool is a subclass of int, so isinstance(True, int) is True. A boolean is
    # never a meaningful integer argument here, so reject it explicitly.
    # None is not an instance of int, so no separate None check is required.
    if isinstance(variable, bool) or not isinstance(variable, int):
        raise ValueError("%s needs to be an int but got '%s'." % (name, type(variable)))
593+
582594
@staticmethod
583595
def _dict_to_jmap(
584596
sparkSession: SparkSession,

python/delta/tests/test_deltatable.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import unittest
1818
import os
19-
from typing import List, Set, Dict, Optional, Any, Union, Tuple
19+
from typing import List, Set, Dict, Optional, Any, Callable, Union, Tuple
2020

2121
from pyspark.sql import DataFrame, Row
2222
from pyspark.sql.column import _to_seq # type: ignore[attr-defined]
@@ -827,6 +827,29 @@ def test_restore_to_timestamp(self) -> None:
827827
restored = DeltaTable.forPath(self.spark, self.tempFile).toDF()
828828
self.__checkAnswer(restored, [Row(key='a', value=1), Row(key='b', value=2)])
829829

830+
# we cannot test the actual working of restore to timestamp here but we can make sure
831+
# that the api is being called at least
832+
def runRestore() -> None:
833+
DeltaTable.forPath(self.spark, self.tempFile).restoreToTimestamp('05/04/1999')
834+
self.__intercept(runRestore, "The provided timestamp ('05/04/1999') "
835+
"cannot be converted to a valid timestamp")
836+
837+
def test_restore_invalid_inputs(self) -> None:
    """Restore APIs reject arguments of the wrong Python type with a clear error."""
    rows = [('a', 1), ('b', 2), ('c', 3)]
    source = self.spark.createDataFrame(rows, ["key", "value"])
    source.write.format("delta").save(self.tempFile)

    table = DeltaTable.forPath(self.spark, self.tempFile)

    # An int where a timestamp string is expected.
    def restore_with_int_timestamp() -> None:
        table.restoreToTimestamp(12342323232)  # type: ignore[arg-type]

    # A str where an int version is expected.
    def restore_with_str_version() -> None:
        table.restoreToVersion("0")  # type: ignore[arg-type]

    self.__intercept(restore_with_int_timestamp,
                     "timestamp needs to be a string but got '<class 'int'>'")
    self.__intercept(restore_with_str_version,
                     "version needs to be an int but got '<class 'str'>'")
852+
830853
def __checkAnswer(self, df: DataFrame,
831854
expectedAnswer: List[Any],
832855
schema: Union[StructType, List[str]] = ["key", "value"]) -> None:
@@ -870,6 +893,15 @@ def __createFile(self, fileName: str, content: Any) -> None:
870893
def __checkFileExists(self, fileName: str) -> bool:
871894
return os.path.exists(os.path.join(self.tempFile, fileName))
872895

896+
def __intercept(self, func: Callable[[], None], exceptionMsg: str) -> None:
    """
    Run ``func`` and require that it raises an exception mentioning ``exceptionMsg``.

    :param func: zero-argument callable that is expected to raise
    :param exceptionMsg: text that must occur in ``str()`` of the raised exception
    :raises AssertionError: if ``func`` raises nothing, or raises an exception
        whose message does not contain ``exceptionMsg``
    """
    try:
        func()
    except Exception as e:
        if exceptionMsg in str(e):
            return
        # Report what was actually raised instead of silently discarding it.
        raise AssertionError(
            "Did not catch expected Exception:" + exceptionMsg +
            " (raised %s: %s)" % (type(e).__name__, e)) from e
    # Raised explicitly rather than through an `assert` statement, which is
    # removed under `python -O` and would let the check pass vacuously.
    raise AssertionError("Did not catch expected Exception:" + exceptionMsg)
904+
873905

874906
if __name__ == "__main__":
875907
try:

0 commit comments

Comments
 (0)