import os
import sys
from xml.sax.saxutils import escape

if '__file__' in globals():
    # Make modules next to this file importable when it is run as a script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(script_dir)

import numpy as np
import pandas as pd

# Characters of a signal-state string treated as green / yellow / red.
GREEN_SET = {'g', 'G'}
YELLOW_SET = {'y'}
RED_SET = {'r'}


def get_intermediate_signal_state(prev, cur, signal_type='yellow'):
    '''
    Build the intermediate signal state shown between two consecutive states.

    Adapted from
    https://github.com/cts198859/deeprl_signal_control/blob/master/envs/env.py#L128-L152

    Args:
        prev: signal-state string before the transition.
        cur: signal-state string after the transition. Assumed to have the
            same length as prev (zip() truncates otherwise).
        signal_type: one of ['yellow', 'red'].
            - 'yellow': indices switching green -> red become 'y'.
            - 'red': indices switching green -> red keep cur's 'r'.
            In both modes, indices switching red -> green are held at 'r'.
            All other indices are taken from cur.

    Returns:
        (state, needed): the intermediate state and whether one is needed.
        Returns (cur, False) when prev == cur, or when signal_type is
        'yellow' and no index switches green -> red.

    Raises:
        ValueError: if signal_type is not 'yellow' or 'red'.
    '''
    if signal_type not in ('yellow', 'red'):
        raise ValueError(f"signal_type must be 'yellow' or 'red', got {signal_type!r}")

    # Identical to the previous state: nothing to insert.
    if prev == cur:
        return cur, False

    # Compare the characters at each index.
    switch_reds, switch_greens = [], []
    for i, (p0, p1) in enumerate(zip(prev, cur)):
        if (p0 in GREEN_SET) and (p1 in RED_SET):
            # green -> red
            switch_reds.append(i)
        elif (p0 in RED_SET) and (p1 in GREEN_SET):
            # red -> green
            switch_greens.append(i)

    # Without any green -> red switch no yellow interval is required.
    if (not switch_reds) and (signal_type == 'yellow'):
        return cur, False

    mid = list(cur)
    if signal_type == 'yellow':
        for i in switch_reds:
            mid[i] = 'y'
    for i in switch_greens:
        mid[i] = 'r'
    return ''.join(mid), True


class SetupTLSProgram:
    '''
    Build a two-ring fixed-time signal plan and write it as a tlLogic XML file.

    Per-phase inputs are ordered ring A phase 1, ring A phase 2, ..., then
    ring B phase 1, ring B phase 2, ... Each ring is expanded into one row
    per unit of duration (red clearance, green, yellow for every phase). The
    two rings are aligned row by row, and identical (ring A, ring B)
    combinations are merged into the phases of the output plan.
    '''

    def __init__(self, durations, reds, yellows, offset, states, path_tll_xml, name=None):
        '''
        Args:
            durations: per-phase total durations (red + green + yellow), in
                the ring A then ring B order. Presumably seconds; confirm.
            reds: per-phase red-clearance durations, same order.
            yellows: per-phase yellow durations, same order.
            offset: value written to the tlLogic 'offset' attribute.
            states: per-phase green signal-state strings, same order.
            path_tll_xml: output path of the XML file written by main().
            name: tlLogic id; the string 'None' is written when omitted.
        '''
        self.durations = self.as_array(durations)
        self.reds = self.as_array(reds)
        self.yellows = self.as_array(yellows)
        self.offset = offset
        self.states = states
        self.path_tll_xml = path_tll_xml
        self.name = name
        self.num_phase = len(states) // 2

    def as_array(self, x):
        '''Return a list/tuple as a numpy array; other values are returned unchanged.'''
        if isinstance(x, (list, tuple)):
            return np.array(x)
        return x

    def _val(self):
        '''
        Validate the plan inputs.

        Raises:
            ValueError: if states has an odd length; if durations, reds or
                yellows do not have len(states) elements; if the states
                differ in length; if the two rings' total durations differ;
                if some signal index is red in every state; or if any
                duration, red, yellow or derived green time is negative.
        '''
        # states must split into two equal halves (ring A, ring B).
        if len(self.states) % 2 != 0:
            raise ValueError("It must have an even number of elements. Odd length detected.")
        length = self.num_phase * 2
        if length != len(self.durations) or length != len(self.reds) or length != len(self.yellows):
            raise ValueError("The length must all match length.")
        if len({len(s) for s in self.states}) != 1:
            raise ValueError("All elements in 'states' must have the same length.")
        # Both rings must have the same cycle length.
        if np.sum(self.durations[:self.num_phase]) != np.sum(self.durations[self.num_phase:]):
            raise ValueError("The sums of the two halves of durations do not match")
        self._check_always_red(self.states)
        self._validate_non_negative_elements(self.durations)
        self._validate_non_negative_elements(self.reds)
        self._validate_non_negative_elements(self.yellows)
        # Derived green time of every phase.
        self._validate_non_negative_elements(self.durations - self.reds - self.yellows)

    def _validate_non_negative_elements(self, lst):
        '''Raise ValueError if any element of lst is negative.'''
        if any(x < 0 for x in lst):
            raise ValueError('The list contains values less than 0. Please ensure all elements in the list are non-negative.')

    def _check_always_red(self, states):
        '''
        Raise ValueError if some signal index is red ('r'/'R') in every state.

        Assumes all states have the same length as states[0].
        '''
        colors = ['r'] * len(states[0])
        for s in states:
            for i, char in enumerate(s.lower()):
                if char != 'r':
                    colors[i] = char
        always_red = [i for i, c in enumerate(colors) if c == 'r']
        if always_red:
            raise ValueError(
                f"Signal indices {always_red} are red ('r') in every state; "
                "each index must be non-red in at least one state."
            )

    def _merge(self, s1, s2):
        '''
        Combine a ring A state and a ring B state into one signal state.

        At each index, a red in one ring yields the other ring's character,
        and equal characters are kept.

        Raises:
            ValueError: if the lengths differ, or both rings show different
                non-red characters at the same index.
        '''
        if len(s1) != len(s2):
            raise ValueError("The lengths of s1 and s2 must be the same.")
        new_s = []
        for c1, c2 in zip(s1, s2):
            if c1 == 'r':
                new_s.append(c2)
            elif c2 == 'r':
                new_s.append(c1)
            elif c1 == c2:
                new_s.append(c1)
            else:
                raise ValueError(f"Unexpected characters encountered: c1={c1}, c2={c2}")
        return ''.join(new_s)

    def _add(self, durations, reds, yellows, states, ring_type='None'):
        '''
        Expand one ring's phases into a table with one row per unit of duration.

        For phase i (phase number i + 1), the rows appear in this order:
        - reds[i] rows of the red-clearance state (transition from the
          previous phase, wrapping around);
        - durations[i] - reds[i] - yellows[i] rows of states[i];
        - yellows[i] rows of the yellow state (transition to the next phase,
          wrapping around).

        Args:
            durations, reds, yellows: per-phase counts for this ring (arrays
                supporting elementwise subtraction).
            states: per-phase signal-state strings for this ring.
            ring_type: prefix for the output column names (e.g. 'A', 'B').

        Returns:
            DataFrame with columns '{ring_type}_ring_phase_no',
            '{ring_type}_ring_color' (1 = red, 2 = green, 3 = yellow) and
            '{ring_type}_ring_state'.

        Raises:
            ValueError: if a yellow is required when leaving a phase but that
                phase's yellow duration is 0.
        '''
        greens = durations - reds - yellows
        num_states = len(states)
        colors, phase_numbers, new_states = [], [], []

        for curr_i, g_state in enumerate(states):
            prev_state = states[(curr_i - 1) % num_states]
            next_state = states[(curr_i + 1) % num_states]
            phase_no = curr_i + 1

            # Red clearance when entering this phase from the previous one.
            r_state, _ = get_intermediate_signal_state(prev_state, g_state, signal_type='red')
            # Yellow when leaving this phase for the next one.
            y_state, has_y = get_intermediate_signal_state(g_state, next_state, signal_type='yellow')
            if has_y and yellows[curr_i] == 0:
                raise ValueError('Yellow signal is required, but the yellow duration is 0.')
            if not has_y and yellows[curr_i] != 0:
                # Nothing switches green -> red, so the yellow time keeps the green state.
                y_state = g_state

            segments = (
                (r_state, 1, reds[curr_i]),     # 'r'
                (g_state, 2, greens[curr_i]),   # 'g'
                (y_state, 3, yellows[curr_i]),  # 'y'
            )
            for state, color, count in segments:
                new_states += [state] * count
                colors += [color] * count
                phase_numbers += [phase_no] * count

        return pd.DataFrame(
            {
                f'{ring_type}_ring_phase_no': phase_numbers,
                f'{ring_type}_ring_color': colors,
                f'{ring_type}_ring_state': new_states,
            }
        )

    def write_tll_xml(self, df_plan, path_tll_xml, program_id='0'):
        '''
        Write df_plan as a static traffic-light program in tlLogic XML form.

        Output layout: a <tlLogics> root containing one
        <tlLogic id=... type="static" programID=... offset=...> element with
        one <phase duration=... state=... name=.../> per row of df_plan.
        The phase name is '<A phase><A color>_<B phase><B color>', e.g. '1g_2r'.
        The target consumer is presumably SUMO (confirm).

        Args:
            df_plan: plan as returned by main().
            path_tll_xml: output file path (written as UTF-8).
            program_id: value of the tlLogic 'programID' attribute.
        '''
        def attr(value):
            # Escape &, <, > and " so any value is safe inside a double-quoted attribute.
            return escape(str(value), {'"': '&quot;'})

        tllogic_id = 'None' if self.name is None else self.name
        lines = ['<tlLogics>\n']
        lines.append(
            f'    <tlLogic id="{attr(tllogic_id)}" type="static" '
            f'programID="{attr(program_id)}" offset="{attr(self.offset)}">\n'
        )
        for _, row in df_plan.iterrows():
            name = f"{row['A_ring_phase_no']}{row['A_ring_color']}_{row['B_ring_phase_no']}{row['B_ring_color']}"
            lines.append(
                f'        <phase duration="{attr(row["duration"])}" '
                f'state="{attr(row["signal_state"])}" name="{attr(name)}"/>\n'
            )
        lines.append('    </tlLogic>\n')
        lines.append('</tlLogics>\n')
        with open(path_tll_xml, 'w', encoding='utf-8') as f:
            f.write(''.join(lines))

    def main(self):
        '''
        Validate the inputs, build the merged two-ring plan and write the XML file.

        Returns:
            DataFrame with one row per output phase. Columns are the ring A/B
            phase number, color ('r'/'g'/'y') and state, 'duration' (the
            number of merged rows) and the merged 'signal_state'.
        '''
        durations, reds, yellows, states = self.durations, self.reds, self.yellows, self.states
        n = self.num_phase
        self._val()

        df_a = self._add(durations[:n], reds[:n], yellows[:n], states[:n], 'A')
        df_b = self._add(durations[n:], reds[n:], yellows[n:], states[n:], 'B')

        group_cols = ['A_ring_phase_no', 'A_ring_color', 'A_ring_state',
                      'B_ring_phase_no', 'B_ring_color', 'B_ring_state']
        # Within a ring, phase numbers increase and the numeric colors run
        # 1 (r) -> 2 (g) -> 3 (y), so sorting on these keys before mapping the
        # codes to letters keeps the segments in their emitted order.
        sort_cols = ['A_ring_phase_no', 'A_ring_color', 'B_ring_phase_no', 'B_ring_color']
        df_plan = (
            pd.concat([df_a, df_b], axis=1)
            .groupby(group_cols)
            .size()
            .reset_index(name='duration')
            .sort_values(by=sort_cols)
            .reset_index(drop=True)
        )

        mapping = {1: 'r', 2: 'g', 3: 'y'}
        df_plan['A_ring_color'] = df_plan['A_ring_color'].map(mapping)
        df_plan['B_ring_color'] = df_plan['B_ring_color'].map(mapping)
        df_plan['signal_state'] = [
            self._merge(a_state, b_state)
            for a_state, b_state in zip(df_plan['A_ring_state'], df_plan['B_ring_state'])
        ]
        # Sanity check: the merged plan spans exactly one cycle of each ring.
        assert df_plan['duration'].sum() == np.sum(durations[:n])
        assert df_plan['duration'].sum() == np.sum(durations[n:])

        self.write_tll_xml(df_plan, self.path_tll_xml)
        return df_plan


if __name__ == '__main__':
    # Test data
    # - Intersection number: 5034
    # - Per-phase values are ordered ring A phase 1, ring A phase 2, ...,
    #   then ring B phase 1, ring B phase 2, ...

    # Overlap test
    durations = [51, 34, 52, 43, 46, 39, 52, 43]
    reds = [2, 2, 2, 2, 2, 2, 2, 2]
    yellows = [4, 4, 4, 4, 4, 4, 4, 4]
    offset = 5

    # # Original
    # durations = [53, 33, 51, 43, 26, 60, 51, 43]
    # reds = [0, 0, 0, 0, 0, 0, 0, 0]
    # yellows = [4, 4, 4, 4, 4, 4, 4, 4]
    # offset = 0

    # # 1795 (2023-12-20 07:00)
    # durations = [58, 29, 50, 43, 28, 59, 50, 43]
    # reds = [0, 0, 0, 0, 0, 0, 0, 0]
    # yellows = [4, 4, 4, 4, 4, 4, 4, 4]
    # offset = 0

    # # 1888 (2023-12-21 07:00)
    # durations = [61, 28, 48, 43, 27, 62, 48, 43]
    # reds = [0, 0, 0, 0, 0, 0, 0, 0]
    # yellows = [4, 4, 4, 4, 4, 4, 4, 4]
    # offset = 0

    # Signal state for the movements of each phase
    signal_states = [
        'rrrrrrrrrGGGGrrrrrrrr',
        'rrrGrrrrrrrrrrrrrrrrr',
        'rrrrGGGrrrrrrrrrrrrrr',
        'rrrrrrrrrrrrrrrrrrrGG',
        'rrrrrrrrrrrrrGGrrrrrr',
        'GGGrrrrrrrrrrrrrrrrrr',
        'rrrrrrrGGrrrrrrrrrrrr',
        'rrrrrrrrrrrrrrrGGGGrr'
    ]

    # Output paths
    path_xml = 'test.tll.xml'
    path_csv = 'test.tll.csv'

    # Node id
    node_id = '11053'  # intersection number: 5034

    args = durations, reds, yellows, offset, signal_states, path_xml, node_id
    SetupTLSProgram(*args).main().to_csv(path_csv, index=False)
    print(f'Saved {path_xml} and {path_csv}')