Skip to content

Commit fbcb36d

Browse files
committed
Remove invalid shapes from Shape internal table.
1 parent 4ad479b commit fbcb36d

File tree

2 files changed

+29
-28
lines changed

2 files changed

+29
-28
lines changed

redeal/redeal.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from array import array
22
from bisect import bisect
33
from collections import Counter
4-
from itertools import permutations, product
4+
from itertools import combinations_with_replacement, permutations
55
from operator import itemgetter, attrgetter
66
import functools
77
import random
@@ -46,9 +46,15 @@ class Shape:
4646
``accept`` function of a simulation, for example..
4747
"""
4848

49-
JOKER = "x"
50-
TABLE = {JOKER: -1, "t": 10, "j": 11, "q": 12, "k": 13, "(": "(", ")": ")"}
51-
TABLE.update({str(n): n for n in range(10)})
49+
_str_to_val = {
50+
"x": -1, "t": 10, "j": 11, "q": 12, "k": 13, "(": "(", ")": ")",
51+
**{str(n): n for n in range(10)}}
52+
_all_shapes = [
53+
(s, sh - s, shd - sh, len(Rank) - shd)
54+
for s, sh, shd
55+
in combinations_with_replacement(range(len(Rank) + 1), len(Suit) - 1)
56+
]
57+
_shape_to_index = {shape: idx for idx, shape in enumerate(_all_shapes)}
5258
_cls_cache = {}
5359

5460
def __new__(cls, init=None):
@@ -57,30 +63,30 @@ def __new__(cls, init=None):
5763
return cls._cls_cache[init]
5864
except KeyError:
5965
self = object.__new__(cls)
60-
self.table = array("b")
61-
self.table.fromlist([0] * (len(Rank) + 1) ** len(Suit))
66+
self._table = array("b")
67+
self._table.fromlist([0] * len(cls._all_shapes))
6268
self.min_ls = [len(Rank) for _ in Suit]
6369
self.max_ls = [0 for _ in Suit]
6470
self._op_cache = {}
6571
if init:
66-
self.insert([self.TABLE[char.lower()] for char in init])
72+
self.insert([self._str_to_val[char.lower()] for char in init])
6773
cls._cls_cache[init] = self
6874
return self
6975

7076
@classmethod
7177
def from_table(cls, table, min_max_hint=None):
7278
"""Initialize from a table."""
7379
self = cls()
74-
self.table = array("b")
75-
self.table.fromlist(list(table))
80+
self._table = array("b")
81+
self._table.fromlist(list(table))
7682
if min_max_hint is not None:
7783
self.min_ls, self.max_ls = min_max_hint
7884
else:
7985
self.min_ls = [len(Rank) for _ in Suit]
8086
self.max_ls = [0 for _ in Suit]
81-
for nonflat in product(*[range(len(Rank) + 1) for _ in Suit]):
82-
if self.table[self._flatten(nonflat)]:
83-
for dim, coord in enumerate(nonflat):
87+
for idx, shape in enumerate(cls._all_shapes):
88+
if self._table[idx]:
89+
for dim, coord in enumerate(shape):
8490
self.min_ls[dim] = min(self.min_ls[dim], coord)
8591
self.max_ls[dim] = max(self.max_ls[dim], coord)
8692
return self
@@ -89,28 +95,21 @@ def from_table(cls, table, min_max_hint=None):
8995
def from_cond(cls, func):
9096
"""Initialize from a shape-accepting function."""
9197
self = cls()
92-
for nonflat in product(*[range(len(Rank) + 1) for _ in Suit]):
93-
if sum(nonflat) == len(Rank) and func(*nonflat):
94-
self.table[self._flatten(nonflat)] = True
95-
for dim, coord in enumerate(nonflat):
98+
for idx, shape in enumerate(cls._all_shapes):
99+
if func(*shape):
100+
self._table[idx] = True
101+
for dim, coord in enumerate(shape):
96102
self.min_ls[dim] = min(self.min_ls[dim], coord)
97103
self.max_ls[dim] = max(self.max_ls[dim], coord)
98104
return self
99105

100-
@staticmethod
101-
def _flatten(index):
102-
"""Transform a 4D index into a 1D index."""
103-
s, h, d, c = index
104-
mul = len(Rank) + 1
105-
return ((((s * mul + h) * mul) + d) * mul) + c
106-
107106
def _insert1(self, shape, safe=True):
108107
"""Insert an element, possibly with "x" but no "()" terms."""
109108
jokers = any(l == -1 for l in shape)
110109
pre_set = sum(l for l in shape if l >= 0)
111110
if not jokers:
112111
if pre_set == len(Rank):
113-
self.table[self._flatten(shape)] = 1
112+
self._table[self._shape_to_index[shape]] = 1
114113
for suit in Suit:
115114
self.min_ls[suit] = min(self.min_ls[suit], shape[suit])
116115
self.max_ls[suit] = max(self.max_ls[suit], shape[suit])
@@ -143,7 +142,7 @@ def insert(self, it, acc=()):
143142

144143
def __contains__(self, int_shape):
145144
"""Check if the given shape is included."""
146-
return self.table[self._flatten(int_shape)]
145+
return self._table[self._shape_to_index[int_shape]]
147146

148147
def __call__(self, hand):
149148
"""Check if the shape of the given hand is included."""
@@ -155,7 +154,7 @@ def __add__(self, other):
155154
return self._op_cache["+", other]
156155
except KeyError:
157156
table = array("b")
158-
table.fromlist([x or y for x, y in zip(self.table, other.table)])
157+
table.fromlist([x or y for x, y in zip(self._table, other._table)])
159158
min_ls = [min(self.min_ls[suit], other.min_ls[suit])
160159
for suit in Suit]
161160
max_ls = [max(self.max_ls[suit], other.max_ls[suit])
@@ -171,7 +170,7 @@ def __sub__(self, other):
171170
except KeyError:
172171
table = array("b")
173172
table.fromlist(
174-
[x and not y for x, y in zip(self.table, other.table)])
173+
[x and not y for x, y in zip(self._table, other._table)])
175174
result = Shape.from_table(table, (self.min_ls, self.max_ls))
176175
self._op_cache["-", other] = result
177176
return result

redeal/smartstack.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ def _prepare(self):
3636
for lvs_hs in product(*[holdings[suit].items() for suit in Suit]):
3737
lvs, hs = zip(*lvs_hs)
3838
ls, vs = zip(*lvs)
39-
if ls in self._shape and sum(vs) in self._values:
39+
if (sum(ls) == len(Rank)
40+
and ls in self._shape
41+
and sum(vs) in self._values):
4042
counter[ls, vs] += reduce(operator.mul, map(len, hs))
4143
patterns, cumsum = zip(*counter.items())
4244
cumsum = list(cumsum)

0 commit comments

Comments
 (0)